sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/token_dispatcher/deepep.py CHANGED
@@ -5,13 +5,15 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
 
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
-from sglang.srt.layers.moe import …
-from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+from sglang.srt.layers.moe.token_dispatcher.base import (
     BaseDispatcher,
     BaseDispatcherConfig,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode, get_deepep_config, is_tbo_enabled
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -40,11 +42,6 @@ from enum import Enum, IntEnum, auto
 import torch
 import torch.distributed as dist
 
-from sglang.srt.layers.moe.ep_moe.kernels import (
-    deepep_permute_triton_kernel,
-    deepep_post_reorder_triton_kernel,
-    deepep_run_moe_deep_preprocess,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
@@ -56,6 +53,7 @@ class DeepEPNormalOutput(NamedTuple):
     """DeepEP normal dispatch output."""
 
     hidden_states: torch.Tensor | Tuple[torch.Tensor, torch.Tensor]
+    # hidden_states_scale
     topk_idx: torch.Tensor
     topk_weights: torch.Tensor
     num_recv_tokens_per_expert: List[int]
@@ -79,24 +77,32 @@ class DeepEPLLOutput(NamedTuple):
         return DispatchOutputFormat.DEEPEP_LL
 
 
-…
+assert isinstance(DeepEPNormalOutput, DispatchOutput)
+assert isinstance(DeepEPLLOutput, DispatchOutput)
 
-…
-    expected_m: int
+
+class DeepEPNormalCombineInput(NamedTuple):
+    """DeepEP normal combine input."""
+
+    pass
 
     @property
-    def format(self) -> …
-        return …
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_NORMAL
 
 
-…
+class DeepEPLLCombineInput(NamedTuple):
+    """DeepEP low latency combine input."""
+
+    pass
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.DEEPEP_LL
+
+
+assert isinstance(DeepEPNormalCombineInput, CombineInput)
+assert isinstance(DeepEPLLCombineInput, CombineInput)
 
 
 class DeepEPDispatchMode(IntEnum):
@@ -272,6 +278,9 @@ class _DeepEPDispatcherImplBase:
         self.num_max_dispatch_tokens_per_rank = get_int_env_var(
             "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", 128
         )
+        # DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024
+        # and the logic requires num-tokens-sent-from-one-rank-to-another-rank less than it
+        assert self.num_max_dispatch_tokens_per_rank <= 1024
 
         self.handle = None
 
@@ -409,7 +418,11 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-…
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            deepep_post_reorder_triton_kernel,
+        )
+
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -495,7 +508,8 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-…
+            # TODO(shuw): pending https://github.com/deepseek-ai/DeepEP/pull/341
+            use_fp8=not get_bool_env_var("SGLANG_DEEPEP_BF16_DISPATCH"),
         )
         return (
             hidden_states,
@@ -523,23 +537,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             masked_m
         )
 
-…
-            expected_m,
-        )
-        else:
-            deepep_output = DeepEPLLOutput(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                masked_m,
-                expected_m,
-            )
+        deepep_output = DeepEPLLOutput(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            masked_m,
+            expected_m,
+        )
         return deepep_output
 
     def _dispatch_core(
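Two behavior notes fall out of the low-latency dispatcher changes above: the per-rank dispatch budget is now capped at 1024 tokens (DeepEP's internode_ll path reserves FINISHED_SUM_TAG=1024), and FP8 dispatch can be turned off through an environment variable. A minimal standalone sketch of those two checks, using plain os.environ instead of sglang's get_int_env_var/get_bool_env_var helpers:

import os

# Standalone restatement of the new deepep.py checks; sglang itself reads these
# values through get_int_env_var / get_bool_env_var from sglang.srt.utils.
num_max_dispatch_tokens_per_rank = int(
    os.environ.get("SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK", "128")
)
# DeepEP internode_ll dispatch uses FINISHED_SUM_TAG=1024, so the number of tokens
# sent from one rank to another must stay at or below that value.
assert num_max_dispatch_tokens_per_rank <= 1024

# New opt-out for FP8 dispatch: setting SGLANG_DEEPEP_BF16_DISPATCH keeps BF16.
use_fp8 = os.environ.get("SGLANG_DEEPEP_BF16_DISPATCH", "0").lower() not in ("1", "true")
print(use_fp8)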
sglang/srt/layers/moe/token_dispatcher/standard.py CHANGED
@@ -1,19 +1,61 @@
 from __future__ import annotations
 
-from typing import NamedTuple
+from typing import TYPE_CHECKING, NamedTuple
 
-…
+import torch
+
+from sglang.srt.layers.moe.token_dispatcher.base import (
+    BaseDispatcher,
+    CombineInput,
+    CombineInputFormat,
     DispatchOutput,
     DispatchOutputFormat,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 class StandardDispatchOutput(NamedTuple):
     """Standard dispatch output."""
 
+    hidden_states: torch.Tensor
+    topk_output: TopKOutput
+
     @property
     def format(self) -> DispatchOutputFormat:
         return DispatchOutputFormat.STANDARD
 
 
 assert isinstance(StandardDispatchOutput, DispatchOutput)
+
+
+class StandardCombineInput(NamedTuple):
+    """Standard combine input."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def format(self) -> CombineInputFormat:
+        return CombineInputFormat.STANDARD
+
+
+assert isinstance(StandardCombineInput, CombineInput)
+
+
+class StandardDispatcher(BaseDispatcher):
+
+    def dispatch(
+        self, hidden_states: torch.Tensor, topk_output: TopKOutput
+    ) -> DispatchOutput:
+        return StandardDispatchOutput(
+            hidden_states=hidden_states, topk_output=topk_output
+        )
+
+    def combine(self, combine_input: CombineInput) -> torch.Tensor:
+        if isinstance(combine_input, StandardCombineInput):
+            return combine_input.hidden_states
+        else:
+            # TODO: this branch should be removed in the future
+            assert isinstance(combine_input, torch.Tensor)
+            return combine_input
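The rewritten standard.py shows the shape of the new dispatcher protocol: dispatch() bundles hidden states with the top-k routing result, and combine() accepts a CombineInput rather than a bare tensor. A small illustration of the two NamedTuples, assuming an environment where the sglang imports resolve and using None as a stand-in for a real TopKOutput:

import torch

from sglang.srt.layers.moe.token_dispatcher.standard import (
    StandardCombineInput,
    StandardDispatchOutput,
)

hidden_states = torch.randn(4, 8)

# Dispatch side: hidden states and routing results now travel together.
dispatch_output = StandardDispatchOutput(hidden_states=hidden_states, topk_output=None)
print(dispatch_output.format)  # DispatchOutputFormat.STANDARD

# Combine side: quant methods hand back a CombineInput instead of a raw tensor,
# and StandardDispatcher.combine() unwraps it to plain hidden states.
combine_input = StandardCombineInput(hidden_states=hidden_states)
print(combine_input.format)  # CombineInputFormat.STANDARD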
sglang/srt/layers/moe/topk.py CHANGED
@@ -19,6 +19,7 @@ import math
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import (
+    TYPE_CHECKING,
     Callable,
     NamedTuple,
     Optional,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.quantization import QuantizationConfig
+
 try:
     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
 except ImportError:
@@ -94,6 +98,7 @@ class TopKConfig:
     torch_native: bool = False
     routed_scaling_factor: Optional[float] = None
     apply_routed_scaling_factor_on_output: bool = False
+    output_format: Optional[TopKOutputFormat] = None
 
 
 # -------------------------------- TopKOutput ---------------------------------------
@@ -196,9 +201,10 @@ class TopK(CustomOp):
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         routed_scaling_factor: Optional[float] = None,
         apply_routed_scaling_factor_on_output: Optional[bool] = False,
-…
+        output_format: Optional[TopKOutputFormat] = None,
     ):
         # NOTE: scoring_func is not used for now, but we keep it for future use
         # see https://github.com/sgl-project/sglang/pull/4505 for more details
@@ -218,11 +224,9 @@ class TopK(CustomOp):
             correction_bias=correction_bias,
             routed_scaling_factor=routed_scaling_factor,
             apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+            output_format=output_format,
         )
 
-        self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
-        self.force_topk = force_topk
-
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -248,7 +252,19 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        if self.…
+        if self.topk_config.output_format is not None:
+            output_format = self.topk_config.output_format
+        elif get_moe_runner_backend().is_triton_kernel():
+            output_format = TopKOutputFormat.TRITON_KERNEL
+        elif (
+            should_use_flashinfer_trtllm_moe()
+            or get_moe_runner_backend().is_flashinfer_mxfp4()
+        ):
+            output_format = TopKOutputFormat.BYPASSED
+        else:
+            output_format = TopKOutputFormat.STANDARD
+
+        if output_format == TopKOutputFormat.TRITON_KERNEL:
             # renormalize=True is equivalent to sm_first=False
             routing_data, gather_idx, scatter_idx = routing(
                 router_logits,
@@ -256,10 +272,7 @@ class TopK(CustomOp):
                 sm_first=not self.topk_config.renormalize,
             )
             return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
-        elif (
-            should_use_flashinfer_trtllm_moe()
-            or get_moe_runner_backend().is_flashinfer_mxfp4()
-        ):
+        elif output_format == TopKOutputFormat.BYPASSED:
             return BypassedTopKOutput(
                 hidden_states=hidden_states,
                 router_logits=router_logits,
@@ -330,6 +343,14 @@ class TopK(CustomOp):
             )
             topk_weights = topk_weights / topk_weights_sum
 
+            if expert_location_dispatch_info is not None:
+                topk_ids = topk_ids_logical_to_physical(
+                    topk_ids, expert_location_dispatch_info
+                )
+            get_global_expert_distribution_recorder().on_select_experts(
+                topk_ids=topk_ids
+            )
+
             return StandardTopKOutput(topk_weights, topk_ids, _)
         else:
             self.topk_config.torch_native = True
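The forward path above replaces the old per-instance use_triton_kernels flag with a single output-format decision. A standalone restatement of that precedence, with plain booleans standing in for the sglang backend queries (the helper below is illustrative, not an sglang API):

from enum import Enum, auto


class TopKOutputFormat(Enum):  # mirrors the enum used in sglang.srt.layers.moe.topk
    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()


def resolve_output_format(
    explicit_format,              # TopKConfig.output_format, may be None
    backend_is_triton_kernel,     # get_moe_runner_backend().is_triton_kernel()
    use_flashinfer_trtllm,        # should_use_flashinfer_trtllm_moe()
    backend_is_flashinfer_mxfp4,  # get_moe_runner_backend().is_flashinfer_mxfp4()
):
    # An explicitly configured format always wins.
    if explicit_format is not None:
        return explicit_format
    if backend_is_triton_kernel:
        return TopKOutputFormat.TRITON_KERNEL
    if use_flashinfer_trtllm or backend_is_flashinfer_mxfp4:
        return TopKOutputFormat.BYPASSED
    return TopKOutputFormat.STANDARD


assert resolve_output_format(None, False, True, False) is TopKOutputFormat.BYPASSED
assert resolve_output_format(None, False, False, False) is TopKOutputFormat.STANDARD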
sglang/srt/layers/moe/utils.py CHANGED
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import importlib.util
+import logging
 from enum import Enum
 from functools import lru_cache
 from typing import TYPE_CHECKING, Optional
@@ -12,11 +13,12 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     is_dp_attention_enabled,
 )
-from sglang.srt.utils import logger
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
 
+logger = logging.getLogger(__name__)
+
 
 class MoeA2ABackend(Enum):
 
@@ -44,9 +46,10 @@ class MoeRunnerBackend(Enum):
     AUTO = "auto"
     TRITON = "triton"
     TRITON_KERNEL = "triton_kernel"
-…
+    FLASHINFER_TRTLLM = "flashinfer_trtllm"
     FLASHINFER_CUTLASS = "flashinfer_cutlass"
     FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+    FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
 
     def is_auto(self):
         return self == MoeRunnerBackend.AUTO
@@ -58,11 +61,14 @@ class MoeRunnerBackend(Enum):
         return self == MoeRunnerBackend.TRITON_KERNEL
 
     def is_flashinfer_trtllm(self):
-        return self == MoeRunnerBackend.…
+        return self == MoeRunnerBackend.FLASHINFER_TRTLLM
 
     def is_flashinfer_cutlass(self):
         return self == MoeRunnerBackend.FLASHINFER_CUTLASS
 
+    def is_flashinfer_cutedsl(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTEDSL
+
     def is_flashinfer_mxfp4(self):
         return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
@@ -131,7 +137,7 @@ def get_moe_a2a_backend() -> MoeA2ABackend:
     global MOE_A2A_BACKEND
     if MOE_A2A_BACKEND is None:
         logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
-        MOE_A2A_BACKEND = MoeA2ABackend…
+        MOE_A2A_BACKEND = MoeA2ABackend.NONE
     return MOE_A2A_BACKEND
 
 
@@ -139,7 +145,7 @@ def get_moe_runner_backend() -> MoeRunnerBackend:
     global MOE_RUNNER_BACKEND
     if MOE_RUNNER_BACKEND is None:
        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
-        MOE_RUNNER_BACKEND = MoeRunnerBackend…
+        MOE_RUNNER_BACKEND = MoeRunnerBackend.AUTO
     return MOE_RUNNER_BACKEND
 
 
@@ -147,7 +153,7 @@ def get_deepep_mode() -> DeepEPMode:
     global DEEPEP_MODE
     if DEEPEP_MODE is None:
         logger.warning("DEEPEP_MODE is not initialized, using auto mode")
-        DEEPEP_MODE = DeepEPMode…
+        DEEPEP_MODE = DeepEPMode.AUTO
     return DEEPEP_MODE
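The three fallback fixes in utils.py all address the same bug: when the globals were uninitialized, the code assigned the enum class itself instead of a member, so later calls such as .is_auto() would fail. A small self-contained illustration of the difference:

from enum import Enum


class MoeRunnerBackend(Enum):  # trimmed copy for illustration
    AUTO = "auto"
    TRITON = "triton"

    def is_auto(self):
        return self is MoeRunnerBackend.AUTO


backend = MoeRunnerBackend.AUTO   # fixed fallback: a concrete member
assert backend.is_auto()

broken = MoeRunnerBackend         # old fallback: the class object itself
# broken.is_auto() raises TypeError because there is no instance to bind to.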
sglang/srt/layers/quantization/awq.py CHANGED
@@ -34,7 +34,10 @@ from sglang.srt.layers.quantization.utils import get_scalar_types, replace_param
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.…
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardDispatchOutput,
+        CombineInput,
+    )
 
 from sglang.srt.utils import is_cuda, is_hip
 
@@ -736,24 +739,32 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w2_qzeros", marlin_w2_zp)
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer: torch.nn.Module,
-…
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         assert (
-            moe_runner_config.activation == "silu"
+            self.moe_runner_config.activation == "silu"
         ), "Only SiLU activation is supported."
 
         # The input must currently be float16
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
+
         orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids, router_logits = topk_output
 
-…
+        output = fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -768,3 +779,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             w2_zeros=layer.w2_qzeros,
             num_bits=self.quant_config.weight_bits,
         ).to(orig_dtype)
+        return StandardCombineInput(hidden_states=output)
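One detail the AWQ rewrite keeps intact: the fused Marlin MoE path still runs in float16, so apply() casts the incoming activations down and casts the result back to the caller's dtype before wrapping it in StandardCombineInput. A minimal illustration of that dtype round trip (the multiply stands in for the real kernel call):

import torch

x = torch.randn(4, 16, dtype=torch.bfloat16)

orig_dtype = x.dtype
x16 = x.half()             # the Marlin fused MoE kernel currently expects float16
y = x16 * 2.0              # stand-in for fused_marlin_moe(...)
output = y.to(orig_dtype)  # cast back before returning StandardCombineInput

assert output.dtype == torch.bfloat16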
sglang/srt/layers/quantization/base_config.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 
 import inspect
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
@@ -10,7 +11,7 @@ from torch import nn
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.…
+    from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
 
 
 class QuantizeMethodBase(ABC):
@@ -89,20 +90,24 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        …
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
         raise NotImplementedError
 
+    @abstractmethod
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        raise NotImplementedError
+
     @abstractmethod
     def apply(
         self,
         layer: torch.nn.Module,
-        …
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: DispatchOutput,
+    ) -> CombineInput:
         raise NotImplementedError
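The enlarged abstract surface means every FusedMoE quant method now separates one-time runner setup from per-forward work: create_moe_runner() receives the MoeRunnerConfig once, and apply() consumes a DispatchOutput and returns a CombineInput rather than a bare tensor. A skeletal subclass following the pattern visible in the AWQ and block-wise INT8 methods; the class and its pass-through behavior are illustrative only:

from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


class PassThroughMoEMethod(FusedMoEMethodBase):
    """Illustrative quant method that returns hidden states unchanged."""

    def create_weights(self, layer, num_experts, hidden_size,
                       intermediate_size_per_partition, params_dtype,
                       **extra_weight_attrs):
        pass  # a real method registers w13/w2 parameters on the layer here

    def create_moe_runner(self, layer, moe_runner_config):
        # Called once when the layer is built; apply() reads the stored config later.
        self.moe_runner_config = moe_runner_config

    def apply(self, layer, dispatch_output):
        # New contract: DispatchOutput in, CombineInput out.
        return StandardCombineInput(hidden_states=dispatch_output.hidden_states)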
sglang/srt/layers/quantization/blockwise_int8.py CHANGED
@@ -9,6 +9,8 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
@@ -22,8 +24,10 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.…
-…
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -257,7 +261,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: Module,
         num_experts: int,
         hidden_size: int,
-        …
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -273,25 +277,28 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
         # Required by column parallel or enabling merged weights
-        if …
+        if intermediate_size_per_partition % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
-                f"{…
+                f"{intermediate_size_per_partition} is not divisible by "
                 f"weight quantization block_n = {block_n}."
             )
         if tp_size > 1:
             # Required by row parallel
-            if …
+            if intermediate_size_per_partition % block_k != 0:
                 raise ValueError(
                     f"The input_size of down's weight = "
-                    f"{…
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
 
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, …
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -300,7 +307,10 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
 
         w2_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, …
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
             ),
             requires_grad=False,
         )
@@ -311,7 +321,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.ones(
                 num_experts,
-                2 * ((…
+                2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                 (hidden_size + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
@@ -321,7 +331,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
             torch.ones(
                 num_experts,
                 (hidden_size + block_n - 1) // block_n,
-                (…
+                (intermediate_size_per_partition + block_k - 1) // block_k,
                 dtype=torch.float32,
             ),
             requires_grad=False,
@@ -344,26 +354,27 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         # Block quant doesn't need to process weights after loading
         return
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-…
-        # Expert fusion with INT8 quantization
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
-…
-            w2_scale=…
-…
+            w13_scale=layer.w13_weight_scale_inv,
+            w2_scale=layer.w2_weight_scale_inv,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
         )
+
+        return self.runner.run(dispatch_output, quant_info)
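The scale-tensor shapes in create_weights() follow directly from the block size: each (block_n x block_k) tile of a weight matrix gets one float32 scale, so the per-expert scale grid is the ceiling division of each weight dimension by the corresponding block dimension. A quick worked check of those shapes with example sizes (the sizes themselves are illustrative, not defaults):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1408  # must be divisible by block_n per the check above
block_n, block_k = 128, 128

# w13_weight has shape (E, 2 * I, H); one scale per (block_n, block_k) tile.
w13_scale_shape = (
    num_experts,
    2 * ceil_div(intermediate_size_per_partition, block_n),
    ceil_div(hidden_size, block_k),
)
# w2_weight has shape (E, H, I).
w2_scale_shape = (
    num_experts,
    ceil_div(hidden_size, block_n),
    ceil_div(intermediate_size_per_partition, block_k),
)

print(w13_scale_shape)  # (8, 22, 32)
print(w2_scale_shape)   # (8, 32, 11)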