sglang-0.5.2rc2-py3-none-any.whl → sglang-0.5.3rc0-py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/base.py
@@ -1,9 +1,41 @@
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
  from dataclasses import dataclass
- from typing import Optional
+ from typing import TYPE_CHECKING, Callable, Optional, Tuple, TypeGuard
+
+ import torch
+
+ from sglang.srt.layers.moe.utils import MoeA2ABackend, MoeRunnerBackend
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.moe_runner.triton import (
+         TritonRunnerCore,
+         TritonRunnerInput,
+         TritonRunnerOutput,
+     )
+     from sglang.srt.layers.moe.token_dispatcher import (
+         CombineInput,
+         CombineInputFormat,
+         DispatchOutput,
+         DispatchOutputFormat,
+     )


  @dataclass
  class MoeRunnerConfig:
+
+     # MoE parameters
+     num_experts: Optional[int] = None
+     num_local_experts: Optional[int] = None
+     hidden_size: Optional[int] = None
+     intermediate_size_per_partition: Optional[int] = None
+     layer_id: Optional[int] = None
+     top_k: Optional[int] = None
+     num_fused_shared_experts: Optional[int] = None
+     params_dtype: Optional[torch.dtype] = None
+
+     # Runner configuration
      activation: str = "silu"
      apply_router_weight_on_input: bool = False
      inplace: bool = True
@@ -11,3 +43,244 @@ class MoeRunnerConfig:
      routed_scaling_factor: Optional[float] = None
      gemm1_alpha: Optional[float] = None
      gemm1_clamp_limit: Optional[float] = None
+
+
+ @dataclass
+ class RunnerInput(ABC):
+
+     @property
+     @abstractmethod
+     def runner_backend(self) -> MoeRunnerBackend: ...
+
+     def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerInput]:
+         return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+ class RunnerOutput(ABC):
+
+     @property
+     @abstractmethod
+     def runner_backend(self) -> MoeRunnerBackend: ...
+
+     def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerOutput]:
+         return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+ @dataclass
+ class MoeQuantInfo(ABC):
+     """Moe quantization data."""
+
+     pass
+
+
+ class MoeRunnerCore(ABC):
+
+     def __init__(self, config: MoeRunnerConfig):
+         self.config = config
+
+     @abstractmethod
+     def run(
+         self, runner_input: RunnerInput, quant_info: MoeQuantInfo, running_state: dict
+     ) -> RunnerOutput:
+         pass
+
+     @property
+     @abstractmethod
+     def runner_backend(self) -> MoeRunnerBackend: ...
+
+     def runner_backend_is_triton(self) -> TypeGuard[TritonRunnerCore]:
+         return self.runner_backend == MoeRunnerBackend.TRITON
+
+
+ class FusedOpPool:
+
+     _fused_funcs: dict[str, Callable] = {}
+
+     @classmethod
+     def register_fused_func(
+         cls, a2a_backend_name: str, runner_backend_name: str, fused_func: Callable
+     ):
+         key = (a2a_backend_name, runner_backend_name)
+         if key in cls._fused_funcs:
+             raise ValueError(
+                 f"Fused function for {a2a_backend_name} to {runner_backend_name} is already registered."
+             )
+         assert MoeA2ABackend(
+             a2a_backend_name
+         ), f"Invalid dispatch name: {a2a_backend_name}"
+         assert MoeRunnerBackend(
+             runner_backend_name
+         ), f"Invalid runner name: {runner_backend_name}"
+         cls._fused_funcs[key] = fused_func
+
+     @classmethod
+     def get_fused_func(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
+         key = (dispatch_name, runner_name)
+         fused_func = cls._fused_funcs.get(key)
+         return fused_func
+
+
+ class PermuteMethodPool:
+
+     _pre_permute_methods: dict[
+         Tuple[DispatchOutputFormat, MoeRunnerBackend], Callable
+     ] = {}
+     _post_permute_methods: dict[
+         Tuple[MoeRunnerBackend, CombineInputFormat], Callable
+     ] = {}
+
+     @classmethod
+     def register_pre_permute(
+         cls,
+         dispatch_output_name: str,
+         runner_backend_name: str,
+         permute_func: Callable,
+     ):
+         """
+         Register a customized pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+         :param dispatch_output_name: The DispatchOutputFormat name.
+         :param runner_backend_name: The MoeRunnerBackend name.
+         :param permute_func: The permute function to register.
+         """
+         # TODO: check if registration is valid
+         key = (dispatch_output_name, runner_backend_name)
+         if key in cls._pre_permute_methods:
+             raise ValueError(
+                 f"Pre-permute method for {dispatch_output_name} to {runner_backend_name} is already registered."
+             )
+         cls._pre_permute_methods[key] = permute_func
+
+     @classmethod
+     def register_post_permute(
+         cls,
+         runner_backend_name: str,
+         combine_input_name: str,
+         permute_func: Callable,
+     ):
+         """
+         Register a customized post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+         :param runner_backend_name: The MoeRunnerBackend name.
+         :param combine_input_name: The CombineInputFormat name.
+         :param permute_func: The permute function to register.
+         """
+         # TODO: check if registration is valid
+         key = (runner_backend_name, combine_input_name)
+         if key in cls._post_permute_methods:
+             raise ValueError(
+                 f"Post-permute method for {runner_backend_name} to {combine_input_name} is already registered."
+             )
+         cls._post_permute_methods[key] = permute_func
+
+     @classmethod
+     def get_pre_permute(
+         cls,
+         dispatch_output_format: DispatchOutputFormat,
+         runner_input_format: MoeRunnerBackend,
+     ) -> Callable:
+         """
+         Retrieve the pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+         :param dispatch_output_format: The DispatchOutputFormat type.
+         :param runner_input_format: The MoeRunnerBackend type.
+         :return: The registered permute function or None if not found.
+         """
+         key = (dispatch_output_format, runner_input_format)
+         pre_permute_func = cls._pre_permute_methods.get(key)
+         assert (
+             pre_permute_func is not None
+         ), f"Pre-permute function for {dispatch_output_format} to {runner_input_format} is not registered"
+         return pre_permute_func
+
+     @classmethod
+     def get_post_permute(
+         cls,
+         runner_output_format: MoeRunnerBackend,
+         combine_input_format: CombineInputFormat,
+     ) -> Callable:
+         """
+         Retrieve the post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+         :param runner_output_format: The MoeRunnerBackend type.
+         :param combine_input_format: The CombineInputFormat type.
+         :return: The registered permute function or None if not found.
+         """
+         key = (runner_output_format, combine_input_format)
+         post_permute_func = cls._post_permute_methods.get(key)
+         assert (
+             post_permute_func is not None
+         ), f"Post-permute function for {runner_output_format} to {combine_input_format} is not registered"
+         return post_permute_func
+
+
+ def register_fused_func(
+     a2a_backend_name: str,
+     runner_backend_name: str,
+ ) -> Callable:
+     """
+     Decorator to register a fused function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+     :param a2a_backend_name: The A2A backend name.
+     :param runner_backend_name: The MoeRunnerBackend name.
+     :return: The decorator function.
+     """
+
+     def decorator(fused_func: Callable):
+         FusedOpPool.register_fused_func(
+             a2a_backend_name, runner_backend_name, fused_func
+         )
+         return fused_func
+
+     return decorator
+
+
+ def register_pre_permute(
+     dispatch_output_name: str,
+     runner_backend_name: str,
+ ) -> Callable:
+     """
+     Decorator to register a pre-permute function for the given DispatchOutputFormat and MoeRunnerBackend.
+
+     :param dispatch_output_name: The DispatchOutputFormat name.
+     :param runner_backend_name: The MoeRunnerBackend name.
+     :return: The decorator function.
+     """
+
+     def decorator(
+         permute_func: Callable[
+             [DispatchOutput, MoeQuantInfo, MoeRunnerConfig, dict], RunnerInput
+         ]
+     ) -> Callable:
+
+         PermuteMethodPool.register_pre_permute(
+             dispatch_output_name, runner_backend_name, permute_func
+         )
+         return permute_func
+
+     return decorator
+
+
+ def register_post_permute(
+     runner_backend_name: str,
+     combine_input_name: str,
+ ) -> Callable:
+     """
+     Decorator to register a post-permute function for the given MoeRunnerBackend and CombineInputFormat.
+
+     :param runner_backend_name: The MoeRunnerBackend name.
+     :param combine_input_name: The CombineInputFormat name.
+     :return: The decorator function.
+     """
+
+     def decorator(
+         permute_func: Callable[
+             [RunnerOutput, MoeQuantInfo, MoeRunnerConfig, dict], CombineInput
+         ]
+     ) -> Callable:
+         PermuteMethodPool.register_post_permute(
+             runner_backend_name, combine_input_name, permute_func
+         )
+         return permute_func
+
+     return decorator
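
The new base.py above is essentially a pair of registries: FusedOpPool maps an (A2A backend, runner backend) name pair to a single fused MoE function, and PermuteMethodPool maps format/backend pairs to pre- and post-permute conversion functions, both populated through module-level decorators. For orientation, here is a minimal, self-contained sketch of that pattern; the pool name, the "standard"/"triton" keys, and the identity permute function are illustrative stand-ins, not sglang's real types:

# Illustrative sketch, not sglang code: a class-level registry keyed by
# (dispatch format, runner backend) name pairs, filled via a decorator.
from typing import Callable, Optional


class PermutePool:
    _methods: dict[tuple[str, str], Callable] = {}

    @classmethod
    def register(cls, dispatch_name: str, runner_name: str, func: Callable) -> None:
        key = (dispatch_name, runner_name)
        if key in cls._methods:
            # Mirrors the real pools: duplicate registration is an error.
            raise ValueError(f"Permute method for {key} is already registered.")
        cls._methods[key] = func

    @classmethod
    def get(cls, dispatch_name: str, runner_name: str) -> Optional[Callable]:
        return cls._methods.get((dispatch_name, runner_name))


def register_pre_permute(dispatch_name: str, runner_name: str) -> Callable:
    # Decorator form, as in base.py: importing the defining module is
    # enough to populate the pool.
    def decorator(func: Callable) -> Callable:
        PermutePool.register(dispatch_name, runner_name, func)
        return func

    return decorator


@register_pre_permute("standard", "triton")
def standard_to_triton(tokens: list[int]) -> list[int]:
    # A real pre-permute reorders tokens by expert; identity stands in here.
    return tokens


assert PermutePool.get("standard", "triton") is standard_to_triton

Raising ValueError on duplicate keys, as the real pools do, keeps plugin-style registrations from silently shadowing one another.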
sglang/srt/layers/moe/moe_runner/runner.py
@@ -0,0 +1,80 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ from typing import TYPE_CHECKING
+
+ from sglang.srt.layers.moe.moe_runner.base import (
+     FusedOpPool,
+     MoeRunnerConfig,
+     PermuteMethodPool,
+ )
+ from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+ from sglang.srt.layers.moe.utils import get_moe_a2a_backend
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.moe_runner.base import MoeQuantInfo
+     from sglang.srt.layers.moe.token_dispatcher.base import CombineInput, DispatchOutput
+     from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+ logger = logging.getLogger(__name__)
+
+
+ class MoeRunner:
+
+     def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
+         self.runner_backend = runner_backend
+         self.config = config
+
+         self.fused_func = None
+
+         if runner_backend.is_triton():
+             self.runner_core = TritonRunnerCore(config)
+         else:
+             raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
+
+         a2a_backend_name = get_moe_a2a_backend().value
+         runner_backend_name = runner_backend.value
+
+         self.fused_func = FusedOpPool.get_fused_func(
+             a2a_backend_name, runner_backend_name
+         )
+
+         SGLANG_CI_DISABLE_MOE_FUSED_FUNC = os.environ.get(
+             "SGLANG_CI_DISABLE_MOE_FUSED_FUNC", "0"
+         )
+         if SGLANG_CI_DISABLE_MOE_FUSED_FUNC == "1":
+             logger.info(
+                 "SGLANG_CI_DISABLE_MOE_FUSED_FUNC is set to 1, disabling fused func"
+             )
+             self.fused_func = None
+
+     def run(
+         self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
+     ) -> CombineInput:
+
+         if self.fused_func is not None:
+             return self.fused_func(dispatch_output, quant_info, self.config)
+
+         dispatch_format = dispatch_output.format.value
+         runner_format = self.runner_core.runner_backend.value
+         self.pre_permute_func = PermuteMethodPool.get_pre_permute(
+             dispatch_format, runner_format
+         )
+
+         running_state = {}
+         runner_input = self.pre_permute_func(
+             dispatch_output, quant_info, self.config, running_state
+         )
+         runner_output = self.runner_core.run(runner_input, quant_info, running_state)
+
+         runner_format = self.runner_core.runner_backend.value
+         combine_format = dispatch_output.format.value
+         self.post_permute_func = PermuteMethodPool.get_post_permute(
+             runner_format, combine_format
+         )
+         combine_input = self.post_permute_func(
+             runner_output, quant_info, self.config, running_state
+         )
+
+         return combine_input
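
MoeRunner.run above encodes a two-path control flow: if a fused function is registered for the active (A2A backend, runner backend) pair, it handles the whole dispatch-to-combine step in one call; otherwise the runner looks up a pre-permute function, runs the backend core, then a post-permute function, threading a shared running_state dict through all three stages. A stripped-down, self-contained sketch of that flow (stand-in callables, not sglang's API):

# Illustrative sketch, not sglang code: the fused-path-or-permute-path
# control flow of MoeRunner.run, with a shared running_state dict.
from typing import Any, Callable, Optional


def run_moe(
    dispatch_output: Any,
    fused_func: Optional[Callable],
    pre_permute: Callable,
    core_run: Callable,
    post_permute: Callable,
) -> Any:
    if fused_func is not None:
        # Fast path: one fused op covers permute + expert GEMMs + un-permute.
        return fused_func(dispatch_output)
    running_state: dict = {}  # scratch shared by all three stages
    runner_input = pre_permute(dispatch_output, running_state)
    runner_output = core_run(runner_input, running_state)
    return post_permute(runner_output, running_state)


# Toy usage: each stage records itself, so the slow path is observable.
trace: list[str] = []


def pre(x, state):
    trace.append("pre")
    return sorted(x)


def core(x, state):
    trace.append("core")
    return x


def post(x, state):
    trace.append("post")
    return x


assert run_moe([3, 1, 2], None, pre, core, post) == [1, 2, 3]
assert trace == ["pre", "core", "post"]

In the real constructor, setting SGLANG_CI_DISABLE_MOE_FUSED_FUNC=1 clears the fused function, forcing the permute path so CI can exercise it even when a fused op is registered.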