sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/triton.py
@@ -0,0 +1,448 @@
+ from __future__ import annotations
+
+ import functools
+ import os
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, List, Optional
+
+ import torch
+ import triton.language as tl
+
+ from sglang.srt.layers.moe.moe_runner.base import (
+     MoeQuantInfo,
+     MoeRunnerConfig,
+     MoeRunnerCore,
+     RunnerInput,
+     RunnerOutput,
+     register_fused_func,
+     register_post_permute,
+     register_pre_permute,
+ )
+ from sglang.srt.layers.moe.utils import MoeRunnerBackend
+ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.token_dispatcher.standard import (
+         StandardCombineInput,
+         StandardDispatchOutput,
+     )
+
+
+ _is_hip = is_hip()
+ _is_cuda = is_cuda()
+ _is_cpu_amx_available = cpu_has_amx_support()
+ _is_cpu = is_cpu()
+ _use_aiter = bool(int(os.getenv("SGLANG_MOE_USE_AITER", "0")))
+ _MOE_PADDING_SIZE = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+
+
+ if _is_cuda:
+     from sgl_kernel import gelu_and_mul, silu_and_mul
+ elif _is_cpu and _is_cpu_amx_available:
+     pass
+ elif _is_hip:
+     from vllm import _custom_ops as vllm_ops # gelu_and_mul, silu_and_mul
+
+     if _use_aiter:
+         try:
+             from aiter import moe_sum
+         except ImportError:
+             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+
+
+ if _is_cuda or _is_hip:
+     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+
+ @dataclass
+ class TritonRunnerInput(RunnerInput):
+
+     hidden_states: torch.Tensor
+     topk_weights: torch.Tensor
+     topk_ids: torch.Tensor
+     sorted_token_ids: torch.Tensor
+     expert_ids: torch.Tensor
+     num_tokens_post_padded: torch.Tensor
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.TRITON
+
+
+ @dataclass
+ class TritonRunnerOutput(RunnerOutput):
+
+     hidden_states: torch.Tensor
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.TRITON
+
+
+ @dataclass
+ class TritonMoeQuantInfo(MoeQuantInfo):
+     w13_weight: torch.Tensor
+     w2_weight: torch.Tensor
+     b13: Optional[torch.Tensor] = None
+     b2: Optional[torch.Tensor] = None
+     use_fp8_w8a8: bool = False
+     use_int8_w8a8: bool = False
+     use_int8_w8a16: bool = False
+     use_int4_w4a16: bool = False
+     per_channel_quant: bool = False
+     w13_scale: Optional[torch.Tensor] = None
+     w2_scale: Optional[torch.Tensor] = None
+     w13_zp: Optional[torch.Tensor] = None
+     w2_zp: Optional[torch.Tensor] = None
+     a13_scale: Optional[torch.Tensor] = None
+     a2_scale: Optional[torch.Tensor] = None
+     block_shape: Optional[List[int]] = None
+
+
+ class TritonRunnerCore(MoeRunnerCore):
+
+     def __init__(self, config: MoeRunnerConfig):
+         super().__init__(config)
+
+     def run(
+         self,
+         runner_input: TritonRunnerInput,
+         quant_info: TritonMoeQuantInfo,
+         running_state: dict,
+     ) -> TritonRunnerOutput:
+
+         # TODO: move these functions to the triton runner
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+             invoke_fused_moe_kernel,
+             moe_sum_reduce_torch_compile,
+             moe_sum_reduce_triton,
+             swiglu_with_alpha_and_limit,
+         )
+
+         hidden_states = runner_input.hidden_states
+         topk_weights = runner_input.topk_weights
+         topk_ids = runner_input.topk_ids
+         sorted_token_ids = runner_input.sorted_token_ids
+         expert_ids = runner_input.expert_ids
+         num_tokens_post_padded = runner_input.num_tokens_post_padded
+
+         w13 = quant_info.w13_weight
+         w2 = quant_info.w2_weight
+         b13 = quant_info.b13
+         b2 = quant_info.b2
+         a13_scale = quant_info.a13_scale
+         a2_scale = quant_info.a2_scale
+         w13_scale = quant_info.w13_scale
+         w2_scale = quant_info.w2_scale
+         w13_zp = quant_info.w13_zp
+         w2_zp = quant_info.w2_zp
+         block_shape = quant_info.block_shape
+         per_channel_quant = quant_info.per_channel_quant
+         use_fp8_w8a8 = quant_info.use_fp8_w8a8
+         use_int8_w8a8 = quant_info.use_int8_w8a8
+         use_int8_w8a16 = quant_info.use_int8_w8a16
+         use_int4_w4a16 = quant_info.use_int4_w4a16
+
+         activation = self.config.activation
+         no_combine = self.config.no_combine
+         inplace = self.config.inplace
+         gemm1_alpha = self.config.gemm1_alpha
+         gemm1_limit = self.config.gemm1_clamp_limit
+         routed_scaling_factor = self.config.routed_scaling_factor
+         apply_router_weight_on_input = self.config.apply_router_weight_on_input
+
+         M = hidden_states.shape[0]
+         E, N, _ = w13.shape
+         compute_type = (
+             tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+         )
+
+         intermediate_cache1 = torch.empty(
+             (M, topk_ids.shape[1], N),
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+
+         invoke_fused_moe_kernel(
+             hidden_states,
+             w13,
+             b13,
+             intermediate_cache1,
+             a13_scale,
+             w13_scale,
+             w13_zp,
+             topk_weights,
+             topk_ids,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             apply_router_weight_on_input,
+             topk_ids.shape[1],
+             running_state["config"],
+             compute_type=compute_type,
+             use_fp8_w8a8=use_fp8_w8a8,
+             use_int8_w8a8=use_int8_w8a8,
+             use_int8_w8a16=use_int8_w8a16,
+             use_int4_w4a16=use_int4_w4a16,
+             per_channel_quant=per_channel_quant,
+             block_shape=block_shape,
+         )
+
+         intermediate_cache2 = torch.empty(
+             (M * topk_ids.shape[1], N // 2),
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+
+         if activation == "silu":
+             if gemm1_alpha is not None:
+                 assert gemm1_limit is not None
+                 intermediate_cache2 = swiglu_with_alpha_and_limit(
+                     intermediate_cache1.view(-1, N),
+                     gemm1_alpha,
+                     gemm1_limit,
+                 )
+             elif _is_cuda:
+                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+             else:
+                 vllm_ops.silu_and_mul(
+                     intermediate_cache2, intermediate_cache1.view(-1, N)
+                 )
+         elif activation == "gelu":
+             assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
+             assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
+             if _is_cuda:
+                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+             else:
+                 vllm_ops.gelu_and_mul(
+                     intermediate_cache2, intermediate_cache1.view(-1, N)
+                 )
+         else:
+             raise ValueError(f"Unsupported activation: {activation=}")
+
+         intermediate_cache3 = torch.empty(
+             (M, topk_ids.shape[1], w2.shape[1]),
+             device=hidden_states.device,
+             dtype=hidden_states.dtype,
+         )
+
+         if no_combine:
+             assert not inplace
+             out_hidden_states = torch.empty(
+                 (M, topk_ids.shape[1], w2.shape[1]),
+                 device=hidden_states.device,
+                 dtype=hidden_states.dtype,
+             )
+         elif inplace:
+             out_hidden_states = hidden_states
+         else:
+             out_hidden_states = torch.empty_like(hidden_states)
+
+         invoke_fused_moe_kernel(
+             intermediate_cache2,
+             w2,
+             b2,
+             (
+                 intermediate_cache3
+                 if not no_combine and topk_ids.shape[1] != 1
+                 else out_hidden_states.unsqueeze(0)
+             ),
+             a2_scale,
+             w2_scale,
+             w2_zp,
+             topk_weights,
+             topk_ids,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             not apply_router_weight_on_input,
+             1,
+             running_state["config"],
+             compute_type=compute_type,
+             use_fp8_w8a8=use_fp8_w8a8,
+             use_int8_w8a8=use_int8_w8a8,
+             use_int8_w8a16=use_int8_w8a16,
+             use_int4_w4a16=use_int4_w4a16,
+             per_channel_quant=per_channel_quant,
+             block_shape=block_shape,
+         )
+
+         if routed_scaling_factor is None:
+             routed_scaling_factor = 1.0
+
+         if no_combine:
+             pass
+         elif _is_cuda:
+             if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
+                 pass # we write directly into out_hidden_states
+             elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
+                 torch.add(
+                     intermediate_cache3[:, 0],
+                     intermediate_cache3[:, 1],
+                     out=out_hidden_states,
+                 ).squeeze(dim=1)
+             else:
+                 # According to micro benchmark results, torch.compile can get better performance for small token.
+                 if M <= 32:
+                     moe_sum_reduce_torch_compile(
+                         intermediate_cache3.view(*intermediate_cache3.shape),
+                         out_hidden_states,
+                         routed_scaling_factor,
+                     )
+                 else:
+                     moe_sum_reduce_triton(
+                         intermediate_cache3.view(*intermediate_cache3.shape),
+                         out_hidden_states,
+                         routed_scaling_factor,
+                     )
+         elif _is_hip:
+             if _use_aiter:
+                 moe_sum(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states,
+                 )
+             else:
+                 vllm_ops.moe_sum(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states,
+                 )
+         else:
+             vllm_ops.moe_sum(
+                 intermediate_cache3.view(*intermediate_cache3.shape),
+                 out_hidden_states,
+             )
+
+         return TritonRunnerOutput(
+             hidden_states=out_hidden_states,
+         )
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.TRITON
+
+
+ @register_fused_func("none", "triton")
+ def fused_experts_none_to_triton(
+     dispatch_output: StandardDispatchOutput,
+     quant_info: TritonMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+ ) -> StandardCombineInput:
+     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+     from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+     output = fused_experts(
+         hidden_states=dispatch_output.hidden_states,
+         w1=quant_info.w13_weight,
+         w2=quant_info.w2_weight,
+         topk_output=dispatch_output.topk_output,
+         moe_runner_config=runner_config,
+         b1=quant_info.b13,
+         b2=quant_info.b2,
+         use_fp8_w8a8=quant_info.use_fp8_w8a8,
+         use_int8_w8a8=quant_info.use_int8_w8a8,
+         use_int8_w8a16=quant_info.use_int8_w8a16,
+         use_int4_w4a16=quant_info.use_int4_w4a16,
+         per_channel_quant=quant_info.per_channel_quant,
+         w1_scale=quant_info.w13_scale,
+         w2_scale=quant_info.w2_scale,
+         w1_zp=quant_info.w13_zp,
+         w2_zp=quant_info.w2_zp,
+         a1_scale=quant_info.a13_scale,
+         a2_scale=quant_info.a2_scale,
+         block_shape=quant_info.block_shape,
+     )
+
+     return StandardCombineInput(
+         hidden_states=output,
+     )
+
+
+ @register_pre_permute("standard", "triton")
+ def pre_permute_standard_to_triton(
+     dispatch_output: StandardDispatchOutput,
+     quant_info: TritonMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> TritonRunnerInput:
+
+     # NOTE: this is dead code as a fused func for standard format is registered.
+     # This is left here for testing and examples.
+
+     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+         get_config_dtype_str,
+         moe_align_block_size,
+         try_get_optimal_moe_config,
+     )
+     from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+     hidden_states, topk_output = dispatch_output
+
+     assert TopKOutputChecker.format_is_standard(topk_output)
+
+     num_tokens = hidden_states.shape[0]
+     num_local_experts = runner_config.num_local_experts
+
+     if (
+         not (quant_info.use_fp8_w8a8 or quant_info.use_int8_w8a8)
+         or quant_info.block_shape is not None
+         or _use_aiter
+     ):
+         padding_size = 0
+     else:
+         padding_size = _MOE_PADDING_SIZE
+
+     config_dtype = get_config_dtype_str(
+         use_fp8_w8a8=quant_info.use_fp8_w8a8,
+         use_int8_w8a8=quant_info.use_int8_w8a8,
+         use_int8_w8a16=quant_info.use_int8_w8a16,
+         use_int4_w4a16=quant_info.use_int4_w4a16,
+         dtype=hidden_states.dtype,
+     )
+
+     get_config_func = functools.partial(
+         try_get_optimal_moe_config,
+         quant_info.w13_weight.shape,
+         (
+             num_local_experts,
+             quant_info.w2_weight.shape[1],
+             quant_info.w2_weight.shape[2] - padding_size,
+         ),
+         topk_output.topk_ids.shape[1],
+         config_dtype,
+         block_shape=quant_info.block_shape,
+     )
+
+     config = get_config_func(num_tokens)
+
+     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+         topk_output.topk_ids, config["BLOCK_SIZE_M"], num_local_experts
+     )
+
+     running_state["config"] = config
+
+     return TritonRunnerInput(
+         hidden_states=hidden_states,
+         topk_weights=topk_output.topk_weights,
+         topk_ids=topk_output.topk_ids,
+         sorted_token_ids=sorted_token_ids,
+         expert_ids=expert_ids,
+         num_tokens_post_padded=num_tokens_post_padded,
+     )
+
+
+ @register_post_permute("triton", "standard")
+ def post_permute_triton_to_standard(
+     runner_output: TritonRunnerOutput,
+     quant_info: TritonMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> StandardCombineInput:
+
+     # NOTE: this is dead code as a fused func for standard format is registered.
+     # This is left here for testing and examples.
+
+     from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+     return StandardCombineInput(
+         hidden_states=runner_output.hidden_states,
+     )
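
For orientation, the sketch below shows how the three registrations in this new file are meant to chain together for the standard dispatch format: the pre-permute hook builds a TritonRunnerInput, TritonRunnerCore.run launches the two fused MoE kernels, and the post-permute hook wraps the result as a StandardCombineInput. This is a hand-written illustration, not code from the wheel; the real driver lives in sglang/srt/layers/moe/moe_runner/runner.py (not shown in this hunk), and the sketch assumes the register_* decorators return the decorated functions unchanged.

from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import (
    TritonMoeQuantInfo,
    TritonRunnerCore,
    post_permute_triton_to_standard,
    pre_permute_standard_to_triton,
)
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput


def run_standard_moe_via_triton(
    dispatch_output: StandardDispatchOutput,
    quant_info: TritonMoeQuantInfo,
    config: MoeRunnerConfig,
):
    # Shared scratch dict: pre-permute stores the tuned Triton launch config
    # under "config", and TritonRunnerCore.run reads it back.
    running_state: dict = {}
    runner_input = pre_permute_standard_to_triton(
        dispatch_output, quant_info, config, running_state
    )
    runner_output = TritonRunnerCore(config).run(
        runner_input, quant_info, running_state
    )
    # Returns a StandardCombineInput wrapping the combined hidden states.
    return post_permute_triton_to_standard(
        runner_output, quant_info, config, running_state
    )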
sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -1,29 +1,41 @@
- from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+ from sglang.srt.layers.moe.token_dispatcher.base import (
      BaseDispatcher,
      BaseDispatcherConfig,
+     CombineInput,
+     CombineInputChecker,
+     CombineInputFormat,
      DispatchOutput,
      DispatchOutputChecker,
      DispatchOutputFormat,
  )
  from sglang.srt.layers.moe.token_dispatcher.deepep import (
-     AscendDeepEPLLOutput,
      DeepEPConfig,
      DeepEPDispatcher,
+     DeepEPLLCombineInput,
      DeepEPLLOutput,
+     DeepEPNormalCombineInput,
      DeepEPNormalOutput,
  )
- from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
+ from sglang.srt.layers.moe.token_dispatcher.standard import (
+     StandardCombineInput,
+     StandardDispatchOutput,
+ )

  __all__ = [
-     "AscendDeepEPLLOutput",
      "BaseDispatcher",
      "BaseDispatcherConfig",
+     "CombineInput",
+     "CombineInputChecker",
+     "CombineInputFormat",
      "DispatchOutput",
      "DispatchOutputFormat",
      "DispatchOutputChecker",
      "StandardDispatchOutput",
+     "StandardCombineInput",
      "DeepEPConfig",
      "DeepEPDispatcher",
      "DeepEPNormalOutput",
      "DeepEPLLOutput",
+     "DeepEPLLCombineInput",
+     "DeepEPNormalCombineInput",
  ]
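
The practical effect of this hunk is that the combine-side types now sit next to the dispatch-side ones in the package's public surface, for example (an illustrative import, not taken from the diff):

from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    CombineInputChecker,
    StandardCombineInput,
    StandardDispatchOutput,
)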
sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py}
@@ -1,18 +1,23 @@
  from __future__ import annotations

  from abc import ABC, abstractmethod
- from enum import Enum, auto
+ from enum import Enum
  from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable

  import torch

  if TYPE_CHECKING:
      from sglang.srt.layers.moe.token_dispatcher import (
-         AscendDeepEPLLOutput,
+         DeepEPLLCombineInput,
          DeepEPLLOutput,
+         DeepEPNormalCombineInput,
          DeepEPNormalOutput,
+         StandardCombineInput,
          StandardDispatchOutput,
      )
+     from sglang.srt.layers.moe.topk import TopKOutput
+
+ # ------------------------------ Dispatch Output -------------------------------------


  class DispatchOutputChecker:
@@ -41,19 +46,12 @@ class DispatchOutputChecker:
      ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
          return dispatch_output.format.is_deepep()

-     @staticmethod
-     def format_is_ascent_ll(
-         dispatch_output: DispatchOutput,
-     ) -> TypeGuard[AscendDeepEPLLOutput]:
-         return dispatch_output.format.is_ascent_ll()
-

  class DispatchOutputFormat(Enum):

-     STANDARD = auto()
-     DEEPEP_NORMAL = auto()
-     DEEPEP_LL = auto()
-     ASCENT_LL = auto()
+     STANDARD = "standard"
+     DEEPEP_NORMAL = "deepep_normal"
+     DEEPEP_LL = "deepep_ll"

      def is_standard(self) -> bool:
          return self == DispatchOutputFormat.STANDARD
@@ -70,18 +68,68 @@ class DispatchOutputFormat(Enum):
              DispatchOutputFormat.DEEPEP_LL,
          ]

-     def is_ascent_ll(self) -> bool:
-         return self == DispatchOutputFormat.ASCENT_LL
-

  @runtime_checkable
  class DispatchOutput(Protocol):
      """Protocol for dispatch outputs in different formats."""

+     # TODO: add hidden_states to the protocol
+
      @property
      def format(self) -> DispatchOutputFormat: ...


+ # ------------------------------ Combine Input -------------------------------------
+
+
+ class CombineInputChecker:
+     @staticmethod
+     def format_is_standard(
+         combine_input: CombineInput,
+     ) -> TypeGuard[StandardCombineInput]:
+         return combine_input.format == CombineInputFormat.STANDARD
+
+     @staticmethod
+     def format_is_deepep_normal(
+         combine_input: CombineInput,
+     ) -> TypeGuard[DeepEPNormalCombineInput]:
+         return combine_input.format == CombineInputFormat.DEEPEP_NORMAL
+
+     @staticmethod
+     def format_is_deepep_ll(
+         combine_input: CombineInput,
+     ) -> TypeGuard[DeepEPLLCombineInput]:
+         return combine_input.format == CombineInputFormat.DEEPEP_LL
+
+     @staticmethod
+     def format_is_deepep(
+         combine_input: CombineInput,
+     ) -> TypeGuard[Union[DeepEPNormalCombineInput, DeepEPLLCombineInput]]:
+         return combine_input.format in [
+             CombineInputFormat.DEEPEP_NORMAL,
+             CombineInputFormat.DEEPEP_LL,
+         ]
+
+
+ class CombineInputFormat(Enum):
+     STANDARD = "standard"
+     DEEPEP_NORMAL = "deepep_normal"
+     DEEPEP_LL = "deepep_ll"
+
+
+ @runtime_checkable
+ class CombineInput(Protocol):
+     """Protocol for combine inputs in different formats."""
+
+     # TODO: add hidden_states to the protocol
+
+     @property
+     def format(self) -> CombineInputFormat: ...
+
+
+ # ------------------------------ Base Dispatcher -------------------------------------
+
+
  class BaseDispatcherConfig(ABC):
      """Base class for dispatcher configs."""

@@ -92,9 +140,11 @@ class BaseDispatcher(ABC):
      """Base class for dispatchers."""

      @abstractmethod
-     def dispatch(self, *args, **kwargs) -> DispatchOutput:
+     def dispatch(
+         self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs
+     ) -> DispatchOutput:
          pass

      @abstractmethod
-     def combine(self, *args, **kwargs) -> torch.Tensor:
+     def combine(self, combine_input: CombineInput, **kwargs) -> torch.Tensor:
          pass
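
To make the intent of the new protocol concrete, here is a minimal sketch (not part of the package) of how a combine() path can use the CombineInputChecker TypeGuards introduced above to narrow a CombineInput before touching format-specific fields:

import torch

from sglang.srt.layers.moe.token_dispatcher import (
    CombineInput,
    CombineInputChecker,
)


def combine_hidden_states(combine_input: CombineInput) -> torch.Tensor:
    if CombineInputChecker.format_is_standard(combine_input):
        # Narrowed to StandardCombineInput, so hidden_states is known to exist.
        return combine_input.hidden_states
    if CombineInputChecker.format_is_deepep(combine_input):
        # DeepEP inputs would instead be handed to the DeepEP dispatcher's combine.
        raise NotImplementedError("handled by DeepEPDispatcher")
    raise ValueError(f"unsupported combine input format: {combine_input.format}")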