sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,4 @@
+ # NOTE: this file will be separated into sglang/srt/layers/moe/moe_runner/triton_utils.py
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
 
  """Fused MoE kernel."""
@@ -5,39 +6,29 @@
  from __future__ import annotations
 
  import functools
- import json
- import logging
  import os
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+ from typing import TYPE_CHECKING, List, Optional
 
  import torch
- import triton
  import triton.language as tl
 
  from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
- from sglang.srt.layers.moe.topk import StandardTopKOutput
- from sglang.srt.layers.quantization.fp8_kernel import (
-     per_token_group_quant_fp8,
-     scaled_fp8_quant,
-     sglang_per_token_group_quant_fp8,
- )
- from sglang.srt.layers.quantization.int8_kernel import (
-     per_token_group_quant_int8,
-     per_token_quant_int8,
-     sglang_per_token_group_quant_int8,
- )
  from sglang.srt.utils import (
-     ceil_div,
      cpu_has_amx_support,
      direct_register_custom_op,
      get_bool_env_var,
-     get_device_name,
      is_cpu,
      is_cuda,
      is_hip,
-     next_power_of_2,
  )
 
+ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
+ from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
+ from .moe_align_block_size import moe_align_block_size
+
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import StandardTopKOutput
+
  _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
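The definitions removed below are relocated rather than deleted: the new relative imports above resolve to new modules shipped in this release (fused_moe_triton_config.py, fused_moe_triton_kernels.py, and moe_align_block_size.py under sglang/srt/layers/moe/fused_moe_triton/ in the file list). A minimal sketch of the equivalent absolute imports, assuming that layout:

# Illustrative only, not part of the diff: where the moved helpers now resolve from.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
    get_config_dtype_str,
    try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import (
    invoke_fused_moe_kernel,
    moe_sum_reduce_triton,
)
from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
    moe_align_block_size,
)

The StandardTopKOutput import moves under TYPE_CHECKING, so it is only used for annotations, presumably to avoid a circular import at runtime.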
@@ -59,954 +50,9 @@ elif _is_hip:
  else:
      from vllm import _custom_ops as vllm_ops
 
-
- if _is_cuda or _is_hip:
-     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
-
-
- logger = logging.getLogger(__name__)
  padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
 
 
- @triton.jit
- def write_zeros_to_output(
-     c_ptr,
-     stride_cm,
-     stride_cn,
-     pid_n,
-     N,
-     offs_token,
-     token_mask,
-     BLOCK_SIZE_M,
-     BLOCK_SIZE_N,
-     compute_type,
- ):
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
-     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-     tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
- @triton.jit
- def fused_moe_kernel_gptq_awq(
-     # Pointers to matrices
-     a_ptr,
-     b_ptr,
-     c_ptr,
-     b_scale_ptr,
-     b_zp_ptr,
-     topk_weights_ptr,
-     sorted_token_ids_ptr,
-     expert_ids_ptr,
-     num_tokens_post_padded_ptr,
-     # Matrix dimensions
-     N: tl.constexpr,
-     K: tl.constexpr,
-     EM,
-     num_valid_tokens,
-     # The stride variables represent how much to increase the ptr by when
-     # moving by 1 element in a particular dimension. E.g. `stride_am` is
-     # how much to increase `a_ptr` by to get the element one row down
-     # (A has M rows).
-     stride_am,
-     stride_ak,
-     stride_be,
-     stride_bk,
-     stride_bn,
-     stride_cm,
-     stride_cn,
-     stride_bse,
-     stride_bsk,
-     stride_bsn,
-     stride_bze,
-     stride_bzk,
-     stride_bzn,
-     group_size: tl.constexpr,
-     # Meta-parameters
-     BLOCK_SIZE_M: tl.constexpr,
-     BLOCK_SIZE_N: tl.constexpr,
-     BLOCK_SIZE_K: tl.constexpr,
-     GROUP_SIZE_M: tl.constexpr,
-     MUL_ROUTED_WEIGHT: tl.constexpr,
-     top_k: tl.constexpr,
-     compute_type: tl.constexpr,
-     has_zp: tl.constexpr,
-     use_int4_w4a16: tl.constexpr,
-     use_int8_w8a16: tl.constexpr,
-     even_Ks: tl.constexpr,
- ):
-     """
-     Implements the fused computation for a Mixture of Experts (MOE) using
-     token and expert matrices.
-     Key Parameters:
-     - A: The input tensor representing tokens with shape (*, K), where '*' can
-         be any shape representing batches and K is the feature dimension of
-         each token.
-     - B: The stacked MOE weight tensor with shape (E, N, K), where E is
-         the number of experts, K is the input feature dimension, and N is
-         the output feature dimension.
-     - C: The output cache tensor with shape (M, topk, N), where M is the
-         total number of tokens post padding, topk is the number of times
-         each token is repeated, and N is the output feature dimension.
-     - sorted_token_ids: A tensor containing the sorted indices of tokens,
-         repeated topk times and arranged by the expert index they are
-         assigned to.
-     - expert_ids: A tensor containing the indices of the expert for each
-         block. It determines which expert matrix from B should be used for
-         each block in A.
-     This kernel performs the multiplication of a token by its corresponding
-     expert matrix as determined by `expert_ids`. The sorting of
-     `sorted_token_ids` by expert index and padding ensures divisibility by
-     BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
-     multiplication across different blocks processed by the same expert.
-     """
-     # -----------------------------------------------------------
-     # Map program ids `pid` to the block of C it should compute.
-     # This is done in a grouped ordering to promote L2 data reuse.
-     pid = tl.program_id(axis=0)
-     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
-     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-     num_pid_in_group = GROUP_SIZE_M * num_pid_n
-     group_id = pid // num_pid_in_group
-     first_pid_m = group_id * GROUP_SIZE_M
-     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-     pid_n = (pid % num_pid_in_group) // group_size_m
-
-     # ----------------------------------------------------------
-     # Create pointers for the first blocks of A and B.
-     # We will advance this pointer as we move in the K direction
-     # and accumulate
-     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
-     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
-     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
-     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
-         return
-     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
-     token_mask = offs_token < num_valid_tokens
-
-     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-     if off_experts == -1:
-         # -----------------------------------------------------------
-         # Write back zeros to the output when the expert is not
-         # in the current expert parallel rank.
-         write_zeros_to_output(
-             c_ptr,
-             stride_cm,
-             stride_cn,
-             pid_n,
-             N,
-             offs_token,
-             token_mask,
-             BLOCK_SIZE_M,
-             BLOCK_SIZE_N,
-             compute_type,
-         )
-         return
-
-     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
-     offs_k = tl.arange(0, BLOCK_SIZE_K)
-     a_ptrs = a_ptr + (
-         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-     )
-
-     if use_int4_w4a16:
-         b_ptrs = (
-             b_ptr
-             + off_experts * stride_be
-             + (offs_k[:, None] // 2) * stride_bk
-             + offs_bn[None, :] * stride_bn
-         )
-         b_shifter = (offs_k[:, None] % 2) * 4
-     elif use_int8_w8a16:
-         b_ptrs = (
-             b_ptr
-             + off_experts * stride_be
-             + offs_k[:, None] * stride_bk
-             + offs_bn[None, :] * stride_bn
-         )
-
-     if not has_zp and use_int4_w4a16:
-         b_zp_num = 8
-     if not has_zp and use_int8_w8a16:
-         b_zp_num = 128
-     elif has_zp and use_int4_w4a16:
-         b_zp_shifter = (offs_bn[None, :] % 2) * 4
-
-     # -----------------------------------------------------------
-     # Iterate to compute a block of the C matrix.
-     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
-     # of fp32 values for higher accuracy.
-     # `accumulator` will be converted back to fp16 after the loop.
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-         # Load the next block of A and B, generate a mask by checking the
-         # K dimension.
-
-         if not even_Ks:
-             k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
-             k_other = 0.0
-         else:
-             k_mask = None
-             k_other = None
-
-         a = tl.load(
-             a_ptrs,
-             mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-             other=0.0,
-         )
-         b = tl.load(b_ptrs)
-         if use_int4_w4a16:
-             b = (b >> b_shifter) & 0xF
-
-         b_scale_ptrs = (
-             b_scale_ptr
-             + off_experts * stride_bse
-             + offs_bn[None, :] * stride_bsn
-             + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
-         )
-         b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
-         b_scale = b_scale.to(tl.float32)
-
-         if has_zp and use_int4_w4a16:
-             offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
-             b_zp_ptrs = (
-                 b_zp_ptr
-                 + off_experts * stride_bze
-                 + (offs_bn[None, :] // 2) * stride_bzn
-                 + offs_k_true * stride_bzk
-             )
-             b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
-             b_zp = (b_zp >> b_zp_shifter) & 0xF
-             b_zp = b_zp.to(tl.float32)
-         elif has_zp and use_int8_w8a16:
-             offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
-             b_zp_ptrs = (
-                 b_zp_ptr
-                 + off_experts * stride_bze
-                 + offs_bn[None, :] * stride_bzn
-                 + offs_k_true * stride_bzk
-             )
-             b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
-             b_zp = b_zp.to(tl.float32)
-
-         # We accumulate along the K dimension.
-         if has_zp:
-             b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
-         else:
-             b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
-         accumulator = tl.dot(a, b, acc=accumulator)
-
-         # Advance the ptrs to the next K block.
-         a_ptrs += BLOCK_SIZE_K * stride_ak
-         if use_int4_w4a16:
-             b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
-         else:
-             b_ptrs += BLOCK_SIZE_K * stride_bk
-
-     if MUL_ROUTED_WEIGHT:
-         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-         accumulator = accumulator * moe_weight[:, None]
-
-     accumulator = accumulator.to(compute_type)
-     # -----------------------------------------------------------
-     # Write back the block of the output
-     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-     tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
- @triton.jit
- def fused_moe_kernel(
-     # Pointers to matrices
-     a_ptr,
-     b_ptr,
-     bias_ptr,
-     c_ptr,
-     a_scale_ptr,
-     b_scale_ptr,
-     topk_weights_ptr,
-     sorted_token_ids_ptr,
-     expert_ids_ptr,
-     num_tokens_post_padded_ptr,
-     # Matrix dimensions
-     N,
-     K,
-     EM,
-     num_valid_tokens,
-     # The stride variables represent how much to increase the ptr by when
-     # moving by 1 element in a particular dimension. E.g. `stride_am` is
-     # how much to increase `a_ptr` by to get the element one row down
-     # (A has M rows).
-     stride_am,
-     stride_ak,
-     stride_be,
-     stride_bk,
-     stride_bn,
-     stride_bias_e,
-     stride_bias_n,
-     stride_cm,
-     stride_cn,
-     stride_asm,
-     stride_ask,
-     stride_bse,
-     stride_bsk,
-     stride_bsn,
-     # Block size for block-wise quantization
-     group_n: tl.constexpr,
-     group_k: tl.constexpr,
-     # Meta-parameters
-     BLOCK_SIZE_M: tl.constexpr,
-     BLOCK_SIZE_N: tl.constexpr,
-     BLOCK_SIZE_K: tl.constexpr,
-     GROUP_SIZE_M: tl.constexpr,
-     MUL_ROUTED_WEIGHT: tl.constexpr,
-     top_k: tl.constexpr,
-     compute_type: tl.constexpr,
-     use_fp8_w8a8: tl.constexpr,
-     use_int8_w8a8: tl.constexpr,
-     use_int8_w8a16: tl.constexpr,
-     per_channel_quant: tl.constexpr,
-     even_Ks: tl.constexpr,
- ):
-     """
-     Implements the fused computation for a Mixture of Experts (MOE) using
-     token and expert matrices.
-
-     Key Parameters:
-     - A: The input tensor representing tokens with shape (*, K), where '*' can
-         be any shape representing batches and K is the feature dimension of
-         each token.
-     - B: The stacked MOE weight tensor with shape (E, N, K), where E is
-         the number of experts, K is the input feature dimension, and N is
-         the output feature dimension.
-     - C: The output cache tensor with shape (M, topk, N), where M is the
-         total number of tokens post padding, topk is the number of times
-         each token is repeated, and N is the output feature dimension.
-     - sorted_token_ids: A tensor containing the sorted indices of tokens,
-         repeated topk times and arranged by the expert index they are
-         assigned to.
-     - expert_ids: A tensor containing the indices of the expert for each
-         block. It determines which expert matrix from B should be used for
-         each block in A.
-
-     This kernel performs the multiplication of a token by its corresponding
-     expert matrix as determined by `expert_ids`. The sorting of
-     `sorted_token_ids` by expert index and padding ensures divisibility by
-     BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
-     multiplication across different blocks processed by the same expert.
-     """
-     # -----------------------------------------------------------
-     # Map program ids `pid` to the block of C it should compute.
-     # This is done in a grouped ordering to promote L2 data reuse.
-     pid = tl.program_id(axis=0)
-     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
-     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-     num_pid_in_group = GROUP_SIZE_M * num_pid_n
-     group_id = pid // num_pid_in_group
-     first_pid_m = group_id * GROUP_SIZE_M
-     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
-     pid_n = (pid % num_pid_in_group) // group_size_m
-
-     # ----------------------------------------------------------
-     # Create pointers for the first blocks of A and B.
-     # We will advance this pointer as we move in the K direction
-     # and accumulate
-     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
-     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
-     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
-     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
-         return
-     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
-     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
-     offs_token = offs_token.to(tl.int64)
-     token_mask = offs_token < num_valid_tokens
-
-     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-
-     if off_experts == -1:
-         # -----------------------------------------------------------
-         # Write back zeros to the output when the expert is not
-         # in the current expert parallel rank.
-         write_zeros_to_output(
-             c_ptr,
-             stride_cm,
-             stride_cn,
-             pid_n,
-             N,
-             offs_token,
-             token_mask,
-             BLOCK_SIZE_M,
-             BLOCK_SIZE_N,
-             compute_type,
-         )
-         return
-
-     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
-     offs_k = tl.arange(0, BLOCK_SIZE_K)
-     a_ptrs = a_ptr + (
-         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
-     )
-
-     b_ptrs = (
-         b_ptr
-         + off_experts * stride_be
-         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-     )
-     if bias_ptr is not None:
-         bias = tl.load(
-             bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
-         )
-     if use_int8_w8a16:
-         b_scale_ptrs = (
-             b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
-         )
-         b_scale = tl.load(b_scale_ptrs)
-
-     if use_fp8_w8a8 or use_int8_w8a8:
-         # block-wise
-         if group_k > 0 and group_n > 0:
-             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-             offs_bsn = offs_bn // group_n
-             b_scale_ptrs = (
-                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
-             )
-         # channel-wise
-         elif per_channel_quant:
-             b_scale_ptrs = (
-                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
-             )
-             b_scale = tl.load(b_scale_ptrs)
-             # Load per-token scale for activations
-             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
-             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
-         # tensor-wise
-         else:
-             a_scale = tl.load(a_scale_ptr)
-             b_scale = tl.load(b_scale_ptr + off_experts)
-
-     # -----------------------------------------------------------
-     # Iterate to compute a block of the C matrix.
-     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
-     # of fp32 values for higher accuracy.
-     # `accumulator` will be converted back to fp16 after the loop.
-     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-
-     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-         # Load the next block of A and B, generate a mask by checking the
-         # K dimension.
-         if even_Ks:
-             a = tl.load(
-                 a_ptrs,
-                 mask=token_mask[:, None],
-                 other=0.0,
-             )
-             b = tl.load(b_ptrs)
-         else:
-             a = tl.load(
-                 a_ptrs,
-                 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                 other=0.0,
-             )
-             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-
-         # We accumulate along the K dimension.
-         if use_int8_w8a16:
-             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
-         elif use_fp8_w8a8 or use_int8_w8a8:
-             if group_k > 0 and group_n > 0:
-                 k_start = k * BLOCK_SIZE_K
-                 offs_ks = k_start // group_k
-                 a_scale = tl.load(
-                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
-                 )
-                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
-
-                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
-             else:
-                 if use_fp8_w8a8:
-                     accumulator = tl.dot(a, b, acc=accumulator)
-                 else:
-                     accumulator += tl.dot(a, b)
-         else:
-             accumulator += tl.dot(a, b)
-         # Advance the ptrs to the next K block.
-         a_ptrs += BLOCK_SIZE_K * stride_ak
-         b_ptrs += BLOCK_SIZE_K * stride_bk
-
-     if use_int8_w8a16:
-         accumulator *= b_scale
-     elif use_fp8_w8a8 or use_int8_w8a8:
-         if group_k == 0 or group_n == 0:
-             accumulator *= a_scale * b_scale
-
-     if bias_ptr is not None:
-         accumulator += bias
-
-     if MUL_ROUTED_WEIGHT:
-         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-         accumulator *= moe_weight[:, None]
-
-     accumulator = accumulator.to(compute_type)
-     # -----------------------------------------------------------
-     # Write back the block of the output
-     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
-     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-     tl.store(c_ptrs, accumulator, mask=c_mask)
-
-
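Both removed kernels above open with the same grouped program-id mapping ("grouped ordering to promote L2 data reuse"). A minimal pure-Python sketch of that mapping, with made-up block counts, shows the traversal order it produces; it is an illustration, not part of the diff:

def map_pid(pid, num_pid_m, num_pid_n, GROUP_SIZE_M):
    # Mirrors the pid -> (pid_m, pid_n) mapping at the top of both kernels.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# With a 4x4 block grid and GROUP_SIZE_M=2, two row-blocks sweep across the
# column-blocks together, so consecutive programs reuse the same B tiles:
print([map_pid(p, 4, 4, 2) for p in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]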
- def moe_align_block_size(
-     topk_ids: torch.Tensor, block_size: int, num_experts: int
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-     """
-     Aligns the token distribution across experts to be compatible with block
-     size for matrix multiplication.
-
-     Parameters:
-     - topk_ids: A tensor of shape [total_tokens, top_k] representing the
-         top-k expert indices for each token.
-     - block_size: The block size used in block matrix multiplication.
-     - num_experts: The total number of experts.
-
-     Returns:
-     - sorted_token_ids: A tensor containing the sorted token indices according
-         to their allocated expert.
-     - expert_ids: A tensor indicating the assigned expert index for each block.
-     - num_tokens_post_padded: The total number of tokens after padding,
-         ensuring divisibility by block_size.
-
-     This function pads the number of tokens that each expert needs to process
-     so that it is divisible by block_size.
-     Padding ensures that during block matrix multiplication, the dimensions
-     align correctly.
-
-     Example:
-     Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
-     block_size = 4, and num_experts = 4:
-     - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
-         with each expert needing to process 3 tokens.
-     - As block_size is 4, we pad 1 token for each expert.
-     - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
-     - Then append padding tokens [12, 12, 12, 12] for each block.
-     - After sorting by expert index, we obtain token_ids
-         [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
-         Tokens 12 are non-existent (padding) and are ignored in
-         the subsequent matrix multiplication.
-     - The padding ensures that the total number of tokens is now divisible
-         by block_size for proper block matrix operations.
-     """
-     max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
-     sorted_ids = torch.empty(
-         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
-     )
-     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-     expert_ids = torch.empty(
-         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
-     )
-     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-
-     # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
-     cumsum_buffer = torch.empty(
-         (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
-     )
-
-     # Threshold based on benchmark results
-     fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
-     if not fuse_sorted_ids_padding:
-         sorted_ids.fill_(topk_ids.numel())
-
-     sgl_moe_align_block_size(
-         topk_ids,
-         num_experts + 1,
-         block_size,
-         sorted_ids,
-         expert_ids,
-         num_tokens_post_pad,
-         cumsum_buffer,
-         fuse_sorted_ids_padding,
-     )
-     return sorted_ids, expert_ids, num_tokens_post_pad
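The docstring's worked example can be reproduced with a small pure-PyTorch sketch (illustrative only; the real path calls the sgl_kernel op sgl_moe_align_block_size on the device). Expert ids 1..4 follow the example above:

import torch

topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
block_size, num_experts = 4, 4

flat = topk_ids.flatten()            # [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]
pad_id = flat.numel()                # 12, an out-of-range id the matmul ignores
sorted_token_ids = []
for e in range(1, num_experts + 1):  # the docstring's example uses experts 1..4
    tokens = (flat == e).nonzero(as_tuple=True)[0].tolist()
    pad = (-len(tokens)) % block_size  # pad each expert's tokens to block_size
    sorted_token_ids += tokens + [pad_id] * pad
print(sorted_token_ids)  # [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]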
-
-
- def invoke_fused_moe_kernel(
-     A: torch.Tensor,
-     B: torch.Tensor,
-     bias: Optional[torch.Tensor],
-     C: torch.Tensor,
-     A_scale: Optional[torch.Tensor],
-     B_scale: Optional[torch.Tensor],
-     B_zp: Optional[torch.Tensor],
-     topk_weights: torch.Tensor,
-     topk_ids: torch.Tensor,
-     sorted_token_ids: torch.Tensor,
-     expert_ids: torch.Tensor,
-     num_tokens_post_padded: torch.Tensor,
-     mul_routed_weight: bool,
-     top_k: int,
-     config: Dict[str, Any],
-     compute_type: tl.dtype,
-     use_fp8_w8a8: bool,
-     use_int8_w8a8: bool,
-     use_int8_w8a16: bool,
-     use_int4_w4a16: bool,
-     per_channel_quant: bool,
-     block_shape: Optional[List[int]] = None,
-     no_combine: bool = False,
- ) -> None:
-     assert topk_weights.stride(1) == 1
-     assert sorted_token_ids.stride(0) == 1
-
-     padded_size = 0
-     if use_fp8_w8a8:
-         assert B_scale is not None
-         if block_shape is None:
-             # activation tensor-wise fp8 quantization, dynamic or static
-             padded_size = padding_size
-             # activations apply per-token quantization when weights apply per-channel quantization by default
-             A, A_scale = scaled_fp8_quant(
-                 A, A_scale, use_per_token_if_dynamic=per_channel_quant
-             )
-         else:
-             # activation block-wise fp8 quantization
-             assert len(block_shape) == 2
-             block_n, block_k = block_shape[0], block_shape[1]
-             if _is_cuda:
-                 A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
-             else:
-                 A, A_scale = per_token_group_quant_fp8(A, block_k)
-             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-     elif use_int8_w8a8:
-         assert B_scale is not None
-         if block_shape is None:
-             # activation channel-wise int8 quantization
-             assert (
-                 per_channel_quant
-             ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
-             A, A_scale = per_token_quant_int8(A)
-         else:
-             # activation block-wise int8 quantization
-             assert len(block_shape) == 2
-             block_n, block_k = block_shape[0], block_shape[1]
-             if _is_cuda:
-                 A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
-             else:
-                 A, A_scale = per_token_group_quant_int8(A, block_k)
-             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
-             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
-             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-     elif use_int8_w8a16 or use_int4_w4a16:
-         assert B_scale is not None
-         assert block_shape is None or block_shape[0] == 0
-     else:
-         assert A_scale is None
-         assert B_scale is None
-
-     grid = lambda META: (
-         triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
-         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
-     )
-
-     K = B.shape[2] - padded_size
-     if K % config["BLOCK_SIZE_K"] == 0:
-         even_Ks = True
-     else:
-         even_Ks = False
-
-     if (
-         (use_int8_w8a16 or use_int4_w4a16)
-         and block_shape is not None
-         and block_shape[1] > 0
-     ):
-         assert B_scale is not None and B_scale.ndim == 3
-         assert B_zp is None or B_zp.ndim == 3
-         assert bias is None
-         fused_moe_kernel_gptq_awq[grid](
-             A,
-             B,
-             C,
-             B_scale,
-             B_zp,
-             topk_weights,
-             sorted_token_ids,
-             expert_ids,
-             num_tokens_post_padded,
-             B.shape[1],
-             A.shape[1],
-             sorted_token_ids.shape[0],
-             topk_ids.numel(),
-             A.stride(0),
-             A.stride(1),
-             B.stride(0),
-             B.stride(2),
-             B.stride(1),
-             C.stride(1),
-             C.stride(2),
-             B_scale.stride(0),
-             B_scale.stride(2),
-             B_scale.stride(1),
-             B_zp.stride(0) if B_zp is not None else 0,
-             B_zp.stride(2) if B_zp is not None else 0,
-             B_zp.stride(1) if B_zp is not None else 0,
-             group_size=block_shape[1],
-             MUL_ROUTED_WEIGHT=mul_routed_weight,
-             top_k=top_k,
-             compute_type=compute_type,
-             has_zp=B_zp is not None,
-             use_int4_w4a16=use_int4_w4a16,
-             use_int8_w8a16=use_int8_w8a16,
-             even_Ks=even_Ks,
-             **config,
-         )
-
-     else:
-
-         fused_moe_kernel[grid](
-             A,
-             B,
-             bias,
-             C,
-             A_scale,
-             B_scale,
-             topk_weights,
-             sorted_token_ids,
-             expert_ids,
-             num_tokens_post_padded,
-             B.shape[1],
-             B.shape[2] - padded_size,
-             sorted_token_ids.shape[0],
-             topk_ids.numel(),
-             A.stride(0),
-             A.stride(1),
-             B.stride(0),
-             B.stride(2),
-             B.stride(1),
-             bias.stride(0) if bias is not None else 0,
-             bias.stride(1) if bias is not None else 0,
-             C.stride(1),
-             C.stride(2),
-             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-             A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-             B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-             B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-             B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-             0 if block_shape is None else block_shape[0],
-             0 if block_shape is None else block_shape[1],
-             MUL_ROUTED_WEIGHT=mul_routed_weight,
-             top_k=top_k,
-             compute_type=compute_type,
-             use_fp8_w8a8=use_fp8_w8a8,
-             use_int8_w8a8=use_int8_w8a8,
-             use_int8_w8a16=use_int8_w8a16,
-             per_channel_quant=per_channel_quant,
-             even_Ks=even_Ks,
-             **config,
-         )
-
-
- def get_config_file_name(
-     E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None
- ) -> str:
-     device_name = get_device_name().replace(" ", "_")
-     dtype_selector = "" if not dtype else f",dtype={dtype}"
-     block_shape_selector = (
-         "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
-     )
-     return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
-
-
- @functools.lru_cache
- def get_moe_configs(
-     E: int,
-     N: int,
-     dtype: Optional[str],
-     block_n: Optional[int] = 0,
-     block_k: Optional[int] = 0,
- ) -> Optional[Dict[int, Any]]:
-     """
-     Return optimized configurations for the fused MoE kernel.
-
-     The return value will be a dictionary that maps an irregular grid of
-     batch sizes to configurations of the fused_moe kernel. To evaluate the
-     kernel on a given batch size bs, the closest batch size in the grid should
-     be picked and the associated configuration chosen to invoke the kernel.
-     """
-     # Supported Triton versions, should be sorted from the newest to the oldest
-     supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]
-
-     # First look up if an optimized configuration is available in the configs
-     # directory
-     json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])
-
-     # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
-     # so we also include the Triton version as a key for finding the fused_moe_kernel config to achieve the best performance.
-     triton_version = triton.__version__
-     version_dir = f"triton_{triton_version.replace('.', '_')}"
-     config_file_path = os.path.join(
-         os.path.dirname(os.path.realpath(__file__)),
-         "configs",
-         version_dir,
-         json_file_name,
-     )
-     if os.path.exists(config_file_path):
-         with open(config_file_path) as f:
-             # Please note that although we find the config files, performance might still be suboptimal.
-             # This is because the tuning environment might differ from your current environment.
-             # For example, updating the Triton version might cause all old configs to become suboptimal.
-             # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
-             # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
-             logger.info(f"Using MoE kernel config from {config_file_path}.")
-             # If a configuration has been found, return it
-             return {int(key): val for key, val in json.load(f).items()}
-
-     # Searching for other triton versions that supports the same config
-     for try_triton_version in supported_triton_versions:
-         if try_triton_version == triton_version:
-             continue
-         try_config_file_path = os.path.join(
-             os.path.dirname(os.path.realpath(__file__)),
-             "configs",
-             f"triton_{try_triton_version.replace('.', '_')}",
-             json_file_name,
-         )
-         if os.path.exists(try_config_file_path):
-             with open(try_config_file_path) as f:
-                 logger.warning(
-                     f"Config file not found at {config_file_path}. Fallback to triton version {try_triton_version} and use MoE kernel config from {try_config_file_path}. Performance might be sub-optimal!",
-                 )
-                 # If a configuration has been found, return it
-                 return {int(key): val for key, val in json.load(f).items()}
-
-     # If no optimized configuration is available, we will use the default
-     # configuration
-     logger.warning(
-         (
-             "Using default MoE kernel config. Performance might be sub-optimal! "
-             "Config file not found at %s, you can create them with https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton"
-         ),
-         config_file_path,
-     )
-     return None
-
-
- def get_default_config(
-     M: int,
-     E: int,
-     N: int,
-     K: int,
-     topk: int,
-     dtype: Optional[str],
-     is_marlin: bool,
-     block_shape: Optional[List[int]] = None,
- ) -> Dict[str, int]:
-     if dtype == "fp8_w8a8":
-         if block_shape is None:
-             config = {
-                 "BLOCK_SIZE_M": 128,
-                 "BLOCK_SIZE_N": 256,
-                 "BLOCK_SIZE_K": 128,
-                 "GROUP_SIZE_M": 32,
-                 "num_warps": 8,
-                 "num_stages": 2 if _is_hip else 4,
-             }
-             if M <= E:
-                 config = {
-                     "BLOCK_SIZE_M": 64,
-                     "BLOCK_SIZE_N": 128,
-                     "BLOCK_SIZE_K": 128,
-                     "GROUP_SIZE_M": 1,
-                     "num_warps": 4,
-                     "num_stages": 2 if _is_hip else 4,
-                 }
-         else:
-             # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
-             config = {
-                 "BLOCK_SIZE_M": 64,
-                 "BLOCK_SIZE_N": block_shape[0],
-                 "BLOCK_SIZE_K": block_shape[1],
-                 "GROUP_SIZE_M": 32,
-                 "num_warps": 4,
-                 "num_stages": 2 if _is_hip else 3,
-             }
-     else:
-         config = {
-             "BLOCK_SIZE_M": 64,
-             "BLOCK_SIZE_N": 64,
-             "BLOCK_SIZE_K": 32,
-             "GROUP_SIZE_M": 8,
-         }
-         # A heuristic: fused marlin works faster with this config for small M
-         if M <= E or (is_marlin and M <= 32):
-             config = {
-                 "BLOCK_SIZE_M": 16,
-                 "BLOCK_SIZE_N": 32,
-                 "BLOCK_SIZE_K": 64,
-                 "GROUP_SIZE_M": 1,
-             }
-     return config
-
-
- def try_get_optimal_moe_config(
-     w1_shape: Tuple[int, ...],
-     w2_shape: Tuple[int, ...],
-     top_k: int,
-     dtype: Optional[str],
-     M: int,
-     is_marlin: bool = False,
-     block_shape: Optional[List[int]] = None,
- ):
-     from sglang.srt.layers.moe.fused_moe_triton import get_config
-
-     override_config = get_config()
-     if override_config:
-         config = override_config
-     else:
-         # First try to load optimal config from the file
-         E, _, N = w2_shape
-         block_n = block_shape[0] if block_shape else 0
-         block_k = block_shape[1] if block_shape else 0
-         configs = get_moe_configs(E, N, dtype, block_n, block_k)
-
-         if configs:
-             # If an optimal configuration map has been found, look up the
-             # optimal config
-             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-         else:
-             # Else use the default config
-             config = get_default_config(
-                 M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
-             )
-     return config
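The lookup above picks the config entry tuned for the batch size closest to the actual M. A tiny illustrative sketch with hypothetical entries:

# Hypothetical tuned entries keyed by batch size M (real values are kernel configs).
configs = {1: "cfg_small", 16: "cfg_medium", 256: "cfg_large"}
M = 48
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
print(config)  # cfg_medium -- 16 is the grid point closest to M=48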
-
-
- def get_config_dtype_str(
-     dtype: torch.dtype,
-     use_int8_w8a16: Optional[bool] = False,
-     use_int4_w4a16: Optional[bool] = False,
-     use_fp8_w8a8: Optional[bool] = False,
-     use_int8_w8a8: Optional[bool] = False,
- ):
-     if use_fp8_w8a8:
-         return "fp8_w8a8"
-     elif use_int8_w8a8:
-         return "int8_w8a8"
-     elif use_int4_w4a16:
-         return "int4_w4a16"
-     elif use_int8_w8a16:
-         return "int8_w8a16"
-     elif dtype == torch.float:
-         # avoiding cases where kernel fails when float32 MoE
-         # use fp16/bfloat16 configs
-         return "float32"
-     return None
-
-
  def inplace_fused_experts(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
@@ -1276,92 +322,6 @@ def fused_experts(
      )
 
 
- # _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
- @triton.jit
- def _moe_sum_reduce_kernel(
-     input_ptr,
-     input_stride_0,
-     input_stride_1,
-     input_stride_2,
-     output_ptr,
-     output_stride_0,
-     output_stride_1,
-     token_num: int,
-     topk_num: int,
-     hidden_dim: int,
-     routed_scaling_factor: tl.constexpr,
-     BLOCK_M: tl.constexpr,
-     BLOCK_DIM: tl.constexpr,
-     NUM_STAGE: tl.constexpr,
- ):
-     input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
-     input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
-     output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
-
-     token_block_id = tl.program_id(0)
-     dim_block_id = tl.program_id(1)
-
-     token_start = token_block_id * BLOCK_M
-     token_end = min((token_block_id + 1) * BLOCK_M, token_num)
-
-     dim_start = dim_block_id * BLOCK_DIM
-     dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
-
-     offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
-
-     for token_index in range(token_start, token_end):
-         accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
-         input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
-         for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
-             tmp = tl.load(
-                 input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
-             )
-             accumulator += tmp
-         accumulator = accumulator * routed_scaling_factor
-         store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
-         tl.store(
-             store_t_ptr,
-             accumulator.to(input_ptr.dtype.element_ty),
-             mask=offs_dim < dim_end,
-         )
-
-
- def moe_sum_reduce_triton(
-     input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
- ):
-     assert input.is_contiguous()
-     assert output.is_contiguous()
-
-     token_num, topk_num, hidden_dim = input.shape
-     assert output.shape[0] == token_num and output.shape[1] == hidden_dim
-
-     BLOCK_M = 1
-     BLOCK_DIM = 2048
-     NUM_STAGE = 1
-     num_warps = 8
-
-     grid = (
-         triton.cdiv(token_num, BLOCK_M),
-         triton.cdiv(hidden_dim, BLOCK_DIM),
-     )
-
-     _moe_sum_reduce_kernel[grid](
-         input,
-         *input.stride(),
-         output,
-         *output.stride(),
-         token_num=token_num,
-         topk_num=topk_num,
-         hidden_dim=hidden_dim,
-         routed_scaling_factor=routed_scaling_factor,
-         BLOCK_M=BLOCK_M,
-         BLOCK_DIM=BLOCK_DIM,
-         NUM_STAGE=NUM_STAGE,
-         num_warps=num_warps,
-     )
-     return
-
-
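The removed pair sums the per-expert partial outputs over the topk axis and applies the routed scaling factor; the torch.compile fallback kept below performs the same reduction. A reference sketch of the semantics with made-up shapes (illustrative only):

import torch

token_num, topk_num, hidden_dim = 4, 2, 8
x = torch.randn(token_num, topk_num, hidden_dim)  # per-(token, expert) outputs
routed_scaling_factor = 0.5

# Equivalent of moe_sum_reduce_triton(x, out, routed_scaling_factor):
out = x.sum(dim=1) * routed_scaling_factor        # shape (token_num, hidden_dim)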
  @torch.compile
  def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
      torch.sum(x, dim=1, out=out)