sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
Full diff shown for sglang/srt/layers/moe/moe_runner/base.py:

@@ -1,9 +1,41 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Optional
+from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+import torch
+
+from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.triton import (
+        TritonRunnerCore,
+        TritonRunnerInput,
+        TritonRunnerOutput,
+    )
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        CombineInputFormat,
+        DispatchOutput,
+        DispatchOutputFormat,
+    )
 
 
 @dataclass
 class MoeRunnerConfig:
+
+    # MoE parameters
+    num_experts: Optional[int] = None
+    num_local_experts: Optional[int] = None
+    hidden_size: Optional[int] = None
+    intermediate_size_per_partition: Optional[int] = None
+    layer_id: Optional[int] = None
+    top_k: Optional[int] = None
+    num_fused_shared_experts: Optional[int] = None
+    params_dtype: Optional[torch.dtype] = None
+
+    # Runner configuration
     activation: str = "silu"
     apply_router_weight_on_input: bool = False
     inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
     routed_scaling_factor: Optional[float] = None
     gemm1_alpha: Optional[float] = None
     gemm1_clamp_limit: Optional[float] = None
+
+
+@dataclass
+class RunnerInput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class RunnerOutput(ABC):
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+@dataclass
+class MoeQuantInfo(ABC):
+    """Moe quantization data."""
+
+    pass
+
+
+class MoeRunnerCore(ABC):
+
+    def __init__(self, config: MoeRunnerConfig):
+        self.config = config
+
+    @abstractmethod
+    def run(
+        self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+    ) -> RunnerOutput:
+        pass
+
+    @property
+    @abstractmethod
+    def runner_backend(self) -> MoeRunnerBackend: ...
+
+    def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+        return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+class FusedOpPool:
+
+    _fused_funcs: dict[str, Callable] = {}
+
+    @classmethod
+    def register_fused_func(
+        cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+    ):
+        key = (a2a_backend_name, runner_backend_name)
+        if key in cls._fused_funcs:
+            raise ValueError(
+                f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+            )
+        assert MoeA2ABackend(
+            a2a_backend_name
+        ), f"Invalid dispatch name: {a2a_backend_name}"
+        assert MoeRunnerBackend(
+            runner_backend_name
+        ), f"Invalid runner name: {runner_backend_name}"
+        cls._fused_funcs[key] = fused_func
+
+    @classmethod
+    def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+        key = (dispatch_name, runner_name)
+        fused_func = cls._fused_funcs.get(key)
+        return fused_func
+
+
+class PermuteMethodPool:
+
+    _pre_permute_methods: dict[
+        Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+    ] = {}
+    _post_permute_methods: dict[
+        Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+    ] = {}
+
+    @classmethod
+    def register_pre_permute(
+        cls,
+        dispatch_output_name: str,
+        runner_backend_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_name: The DispatchOutputFormat name.
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (dispatch_output_name, runner_backend_name)
+        if key in cls._pre_permute_methods:
+            raise ValueError(
+                f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+            )
+        cls._pre_permute_methods[key] = permute_func
+
+    @classmethod
+    def register_post_permute(
+        cls,
+        runner_backend_name: str,
+        combine_input_name: str,
+        permute_func: Callable,
+    ):
+        """
+        Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_backend_name: The MoeRunnerBackend name.
+        :param combine_input_name: The CombineInputFormat name.
+        :param permute_func: The permute function to register.
+        """
+        # TODO: check if registration is valid
+        key = (runner_backend_name, combine_input_name)
+        if key in cls._post_permute_methods:
+            raise ValueError(
+                f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+            )
+        cls._post_permute_methods[key] = permute_func
+
+    @classmethod
+    def get_pre_permute(
+        cls,
+        dispatch_output_format: DispatchOutputFormat,
+        runner_input_format: MoeRunnerBackend,
+    ) -> Callable:
+        """
+        Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+        :param dispatch_output_format: The DispatchOutputFormat type.
+        :param runner_input_format: The MoeRunnerBackend type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (dispatch_output_format, runner_input_format)
+        pre_permute_func = cls._pre_permute_methods.get(key)
+        assert (
+            pre_permute_func is not None
+        ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+        return pre_permute_func
+
+    @classmethod
+    def get_post_permute(
+        cls,
+        runner_output_format: MoeRunnerBackend,
+        combine_input_format: CombineInputFormat,
+    ) -> Callable:
+        """
+        Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+        :param runner_output_format: The MoeRunnerBackend type.
+        :param combine_input_format: The CombineInputFormat type.
+        :return: The registered permute function or None if not found.
+        """
+        key = (runner_output_format, combine_input_format)
+        post_permute_func = cls._post_permute_methods.get(key)
+        assert (
+            post_permute_func is not None
+        ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+        return post_permute_func
+
+
+def register_fused_func(
+    a2a_backend_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param a2a_backend_name: The A2A backend name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(fused_func: Callable):
+        FusedOpPool.register_fused_func(
+            a2a_backend_name, runner_backend_name, fused_func
+        )
+        return fused_func
+
+    return decorator
+
+
+def register_pre_permute(
+    dispatch_output_name: str,
+    runner_backend_name: str,
+) -> Callable:
+    """
+    Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+    :param dispatch_output_name: The DispatchOutputFormat name.
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+        ]
+    ) -> Callable:
+
+        PermuteMethodPool.register_pre_permute(
+            dispatch_output_name, runner_backend_name, permute_func
+        )
+        return permute_func
+
+    return decorator
+
+
+def register_post_permute(
+    runner_backend_name: str,
+    combine_input_name: str,
+) -> Callable:
+    """
+    Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+    :param runner_backend_name: The MoeRunnerBackend name.
+    :param combine_input_name: The CombineInputFormat name.
+    :return: The decorator function.
+    """
+
+    def decorator(
+        permute_func: Callable[
+            [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+        ]
+    ) -> Callable:
+        PermuteMethodPool.register_post_permute(
+            runner_backend_name, combine_input_name, permute_func
+        )
+        return permute_func
+
+    return decorator
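For orientation, the new base.py API is decorator-driven: a backend registers pre-permute and post-permute adapters (and optionally a fully fused path) keyed by dispatch-format and runner-backend names. The sketch below is illustrative only, assuming sglang 0.5.3rc0; the names "standard" and "triton" and the function bodies are placeholders, not the registrations that actually ship in moe_runner/triton.py.

# Illustrative sketch: register hypothetical permute adapters for a
# dispatch format named "standard" and the Triton runner backend.
from sglang.srt.layers.moe.moe_runner.base import (
    MoeQuantInfo,
    MoeRunnerConfig,
    register_post_permute,
    register_pre_permute,
)


@register_pre_permute("standard", "triton")
def example_pre_permute(
    dispatch_output, quant_info: MoeQuantInfo, config: MoeRunnerConfig, running_state: dict
):
    # Reorder dispatched tokens into the layout the runner kernels expect and
    # stash anything the post-permute step needs into running_state.
    ...


@register_post_permute("triton", "standard")
def example_post_permute(
    runner_output, quant_info: MoeQuantInfo, config: MoeRunnerConfig, running_state: dict
):
    # Undo the permutation and wrap the kernel output as a CombineInput.
    ...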
Full diff shown for sglang/srt/layers/moe/moe_runner/runner.py (new file):

@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    FusedOpPool,
+    MoeRunnerConfig,
+    PermuteMethodPool,
+)
+from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+    from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+    from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+logger = logging.getLogger(__name__)
+
+
+class MoeRunner:
+
+    def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+        self.runner_backend = runner_backend
+        self.config = config
+
+        self.fused_func = None
+
+        if runner_backend.is_triton():
+            self.runner_core = TritonRunnerCore(config)
+        else:
+            raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+        a2a_backend_name = get_moe_a2a_backend().value
+        runner_backend_name = runner_backend.value
+
+        self.fused_func = FusedOpPool.get_fused_func(
+            a2a_backend_name, runner_backend_name
+        )
+
+        SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+            "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+        )
+        if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+            logger.info(
+                "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+            )
+            self.fused_func = None
+
+    def run(
+        self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+    ) -> CombineInput:
+
+        if self.fused_func is not None:
+            return self.fused_func(dispatch_output, quant_info, self.config)
+
+        dispatch_format = dispatch_output.format.value
+        runner_format = self.runner_core.runner_backend.value
+        self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+            dispatch_format, runner_format
+        )
+
+        running_state = {}
+        runner_input = self.pre_permute_func(
+            dispatch_output, quant_info, self.config, running_state
+        )
+        runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+        runner_format = self.runner_core.runner_backend.value
+        combine_format = dispatch_output.format.value
+        self.post_permute_func = PermuteMethodPool.get_post_permute(
+            runner_format, combine_format
+        )
+        combine_input = self.post_permute_func(
+            runner_output, quant_info, self.config, running_state
+        )
+
+        return combine_input
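Taken together, runner.py gives MoE quantization methods a single entry point: construct a MoeRunner once per layer, then call run() each forward pass; the runner uses a registered fused function when one exists for the active A2A and runner backends, otherwise it goes through pre-permute, the backend core, and post-permute. A minimal construction sketch follows, assuming sglang 0.5.3rc0 with the MoE A2A backend already configured by normal server startup; the sizes below are arbitrary placeholders.

# Minimal sketch (placeholder sizes; get_moe_a2a_backend() must already be
# initialized by server startup before constructing the runner).
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
from sglang.srt.layers.moe.utils import MoeRunnerBackend

config = MoeRunnerConfig(
    num_experts=8,
    top_k=2,
    hidden_size=4096,
    activation="silu",
)
runner = MoeRunner(MoeRunnerBackend.TRITON, config)

# Per forward pass: dispatch_output comes from a token dispatcher and
# quant_info from the quantization method.
# combine_input = runner.run(dispatch_output, quant_info)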