sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (256)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +89 -54
  3. sglang/bench_serving.py +437 -40
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/profiler.py +0 -1
  6. sglang/srt/configs/__init__.py +4 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/longcat_flash.py +104 -0
  9. sglang/srt/configs/model_config.py +37 -7
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +1 -1
  12. sglang/srt/connector/base_connector.py +1 -2
  13. sglang/srt/connector/redis.py +2 -2
  14. sglang/srt/connector/serde/__init__.py +1 -1
  15. sglang/srt/connector/serde/safe_serde.py +4 -3
  16. sglang/srt/custom_op.py +11 -1
  17. sglang/srt/debug_utils/dump_comparator.py +81 -44
  18. sglang/srt/debug_utils/dump_loader.py +97 -0
  19. sglang/srt/debug_utils/dumper.py +11 -3
  20. sglang/srt/debug_utils/text_comparator.py +73 -11
  21. sglang/srt/disaggregation/ascend/conn.py +75 -0
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +6 -4
  25. sglang/srt/disaggregation/fake/conn.py +1 -1
  26. sglang/srt/disaggregation/mini_lb.py +6 -420
  27. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  28. sglang/srt/disaggregation/nixl/conn.py +180 -16
  29. sglang/srt/disaggregation/prefill.py +6 -4
  30. sglang/srt/disaggregation/utils.py +5 -50
  31. sglang/srt/distributed/parallel_state.py +94 -58
  32. sglang/srt/entrypoints/engine.py +34 -14
  33. sglang/srt/entrypoints/http_server.py +172 -47
  34. sglang/srt/entrypoints/openai/protocol.py +90 -27
  35. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  36. sglang/srt/entrypoints/openai/serving_chat.py +82 -26
  37. sglang/srt/entrypoints/openai/serving_completions.py +25 -4
  38. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  39. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  40. sglang/srt/eplb/eplb_manager.py +28 -4
  41. sglang/srt/eplb/expert_distribution.py +55 -15
  42. sglang/srt/eplb/expert_location.py +8 -3
  43. sglang/srt/eplb/expert_location_updater.py +1 -1
  44. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  45. sglang/srt/function_call/ebnf_composer.py +11 -9
  46. sglang/srt/function_call/function_call_parser.py +2 -0
  47. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  48. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  49. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  50. sglang/srt/hf_transformers_utils.py +28 -7
  51. sglang/srt/layers/activation.py +44 -9
  52. sglang/srt/layers/attention/aiter_backend.py +93 -68
  53. sglang/srt/layers/attention/ascend_backend.py +381 -136
  54. sglang/srt/layers/attention/fla/chunk.py +242 -0
  55. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  56. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  57. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  58. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  59. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  60. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  61. sglang/srt/layers/attention/fla/index.py +37 -0
  62. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  63. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  64. sglang/srt/layers/attention/fla/op.py +66 -0
  65. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  66. sglang/srt/layers/attention/fla/utils.py +331 -0
  67. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  68. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  69. sglang/srt/layers/attention/flashinfer_backend.py +11 -6
  70. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
  71. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  72. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  73. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  74. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  75. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  76. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  77. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  78. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  79. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  80. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  81. sglang/srt/layers/communicator.py +45 -8
  82. sglang/srt/layers/layernorm.py +54 -12
  83. sglang/srt/layers/logits_processor.py +10 -3
  84. sglang/srt/layers/moe/__init__.py +2 -1
  85. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  86. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  87. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  88. sglang/srt/layers/moe/ep_moe/layer.py +111 -56
  89. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  90. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  94. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  101. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  102. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  103. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  104. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  105. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  106. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  107. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  108. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  109. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  110. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  111. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  112. sglang/srt/layers/moe/topk.py +43 -12
  113. sglang/srt/layers/moe/utils.py +6 -5
  114. sglang/srt/layers/quantization/awq.py +19 -7
  115. sglang/srt/layers/quantization/base_config.py +11 -6
  116. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  119. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  121. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
  122. sglang/srt/layers/quantization/fp8.py +78 -48
  123. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  124. sglang/srt/layers/quantization/fp8_utils.py +45 -31
  125. sglang/srt/layers/quantization/gptq.py +25 -17
  126. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  127. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  128. sglang/srt/layers/quantization/mxfp4.py +93 -68
  129. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  130. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  131. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  132. sglang/srt/layers/quantization/quark/utils.py +97 -0
  133. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  134. sglang/srt/layers/quantization/unquant.py +135 -47
  135. sglang/srt/layers/quantization/utils.py +13 -0
  136. sglang/srt/layers/quantization/w4afp8.py +60 -42
  137. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  138. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  139. sglang/srt/layers/rocm_linear_utils.py +44 -0
  140. sglang/srt/layers/rotary_embedding.py +28 -19
  141. sglang/srt/layers/sampler.py +29 -5
  142. sglang/srt/layers/utils.py +0 -14
  143. sglang/srt/lora/backend/base_backend.py +50 -8
  144. sglang/srt/lora/backend/triton_backend.py +90 -2
  145. sglang/srt/lora/layers.py +32 -0
  146. sglang/srt/lora/lora.py +4 -1
  147. sglang/srt/lora/lora_manager.py +35 -112
  148. sglang/srt/lora/mem_pool.py +24 -10
  149. sglang/srt/lora/utils.py +18 -9
  150. sglang/srt/managers/cache_controller.py +396 -365
  151. sglang/srt/managers/data_parallel_controller.py +30 -15
  152. sglang/srt/managers/detokenizer_manager.py +18 -2
  153. sglang/srt/managers/disagg_service.py +46 -0
  154. sglang/srt/managers/io_struct.py +190 -11
  155. sglang/srt/managers/mm_utils.py +6 -1
  156. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  157. sglang/srt/managers/schedule_batch.py +27 -44
  158. sglang/srt/managers/schedule_policy.py +4 -3
  159. sglang/srt/managers/scheduler.py +148 -122
  160. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  161. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  162. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  163. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  164. sglang/srt/managers/template_manager.py +3 -3
  165. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  166. sglang/srt/managers/tokenizer_manager.py +77 -480
  167. sglang/srt/managers/tp_worker.py +16 -4
  168. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  169. sglang/srt/mem_cache/allocator.py +1 -1
  170. sglang/srt/mem_cache/chunk_cache.py +1 -1
  171. sglang/srt/mem_cache/hicache_storage.py +53 -40
  172. sglang/srt/mem_cache/hiradix_cache.py +196 -104
  173. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  174. sglang/srt/mem_cache/memory_pool.py +395 -53
  175. sglang/srt/mem_cache/memory_pool_host.py +27 -19
  176. sglang/srt/mem_cache/radix_cache.py +6 -6
  177. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  178. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  179. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  180. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  181. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
  182. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  183. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  184. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
  185. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  186. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  187. sglang/srt/metrics/collector.py +484 -63
  188. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  189. sglang/srt/metrics/utils.py +48 -0
  190. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  191. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  192. sglang/srt/model_executor/forward_batch_info.py +72 -18
  193. sglang/srt/model_executor/model_runner.py +190 -32
  194. sglang/srt/model_loader/__init__.py +9 -3
  195. sglang/srt/model_loader/loader.py +33 -28
  196. sglang/srt/model_loader/utils.py +12 -0
  197. sglang/srt/model_loader/weight_utils.py +2 -1
  198. sglang/srt/models/deepseek_v2.py +323 -53
  199. sglang/srt/models/gemma3n_mm.py +1 -1
  200. sglang/srt/models/glm4_moe.py +10 -1
  201. sglang/srt/models/glm4v.py +4 -2
  202. sglang/srt/models/gpt_oss.py +7 -19
  203. sglang/srt/models/internvl.py +28 -0
  204. sglang/srt/models/llama4.py +9 -0
  205. sglang/srt/models/llama_eagle3.py +17 -0
  206. sglang/srt/models/longcat_flash.py +1026 -0
  207. sglang/srt/models/longcat_flash_nextn.py +699 -0
  208. sglang/srt/models/minicpmv.py +165 -3
  209. sglang/srt/models/mllama4.py +25 -0
  210. sglang/srt/models/opt.py +637 -0
  211. sglang/srt/models/qwen2.py +33 -3
  212. sglang/srt/models/qwen2_5_vl.py +91 -42
  213. sglang/srt/models/qwen2_moe.py +79 -14
  214. sglang/srt/models/qwen3.py +8 -2
  215. sglang/srt/models/qwen3_moe.py +39 -8
  216. sglang/srt/models/qwen3_next.py +1039 -0
  217. sglang/srt/models/qwen3_next_mtp.py +109 -0
  218. sglang/srt/models/torch_native_llama.py +1 -1
  219. sglang/srt/models/transformers.py +1 -1
  220. sglang/srt/multimodal/processors/base_processor.py +4 -2
  221. sglang/srt/multimodal/processors/glm4v.py +9 -9
  222. sglang/srt/multimodal/processors/internvl.py +141 -129
  223. sglang/srt/{conversation.py → parser/conversation.py} +38 -5
  224. sglang/srt/parser/harmony_parser.py +588 -0
  225. sglang/srt/parser/reasoning_parser.py +309 -0
  226. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  227. sglang/srt/sampling/sampling_batch_info.py +18 -15
  228. sglang/srt/server_args.py +307 -80
  229. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  230. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  231. sglang/srt/speculative/eagle_worker.py +216 -120
  232. sglang/srt/speculative/spec_info.py +5 -0
  233. sglang/srt/speculative/standalone_worker.py +109 -0
  234. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  235. sglang/srt/utils.py +96 -7
  236. sglang/srt/weight_sync/utils.py +1 -1
  237. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  238. sglang/test/few_shot_gsm8k.py +1 -0
  239. sglang/test/runners.py +4 -0
  240. sglang/test/test_cutlass_moe.py +24 -6
  241. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  242. sglang/test/test_disaggregation_utils.py +66 -0
  243. sglang/test/test_utils.py +25 -1
  244. sglang/utils.py +5 -0
  245. sglang/version.py +1 -1
  246. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
  247. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
  248. sglang/srt/disaggregation/launch_lb.py +0 -131
  249. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  250. sglang/srt/reasoning_parser.py +0 -553
  251. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  252. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  253. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  254. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  255. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  256. {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py (new file)
@@ -0,0 +1,799 @@
+ from __future__ import annotations
+
+ import os
+
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ from sglang.srt.layers.quantization.fp8_kernel import (
+     per_token_group_quant_fp8,
+     scaled_fp8_quant,
+     sglang_per_token_group_quant_fp8,
+ )
+ from sglang.srt.layers.quantization.int8_kernel import (
+     per_token_group_quant_int8,
+     per_token_quant_int8,
+     sglang_per_token_group_quant_int8,
+ )
+ from sglang.srt.utils import (
+     cpu_has_amx_support,
+     get_bool_env_var,
+     is_cpu,
+     is_cuda,
+     is_hip,
+ )
+
+ _is_hip = is_hip()
+ _is_cuda = is_cuda()
+ _is_cpu_amx_available = cpu_has_amx_support()
+ _is_cpu = is_cpu()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+ if _is_cuda:
+     pass
+ elif _is_cpu and _is_cpu_amx_available:
+     pass
+ elif _is_hip:
+     pass
+
+ padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
+
+
+ @triton.jit
+ def write_zeros_to_output(
+     c_ptr,
+     stride_cm,
+     stride_cn,
+     pid_n,
+     N,
+     offs_token,
+     token_mask,
+     BLOCK_SIZE_M,
+     BLOCK_SIZE_N,
+     compute_type,
+ ):
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+ @triton.jit
+ def fused_moe_kernel_gptq_awq(
+     # Pointers to matrices
+     a_ptr,
+     b_ptr,
+     c_ptr,
+     b_scale_ptr,
+     b_zp_ptr,
+     topk_weights_ptr,
+     sorted_token_ids_ptr,
+     expert_ids_ptr,
+     num_tokens_post_padded_ptr,
+     # Matrix dimensions
+     N: tl.constexpr,
+     K: tl.constexpr,
+     EM,
+     num_valid_tokens,
+     # The stride variables represent how much to increase the ptr by when
+     # moving by 1 element in a particular dimension. E.g. `stride_am` is
+     # how much to increase `a_ptr` by to get the element one row down
+     # (A has M rows).
+     stride_am,
+     stride_ak,
+     stride_be,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_bse,
+     stride_bsk,
+     stride_bsn,
+     stride_bze,
+     stride_bzk,
+     stride_bzn,
+     group_size: tl.constexpr,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+     MUL_ROUTED_WEIGHT: tl.constexpr,
+     top_k: tl.constexpr,
+     compute_type: tl.constexpr,
+     has_zp: tl.constexpr,
+     use_int4_w4a16: tl.constexpr,
+     use_int8_w8a16: tl.constexpr,
+     even_Ks: tl.constexpr,
+ ):
+     """
+     Implements the fused computation for a Mixture of Experts (MOE) using
+     token and expert matrices.
+     Key Parameters:
+     - A: The input tensor representing tokens with shape (*, K), where '*' can
+         be any shape representing batches and K is the feature dimension of
+         each token.
+     - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+         the number of experts, K is the input feature dimension, and N is
+         the output feature dimension.
+     - C: The output cache tensor with shape (M, topk, N), where M is the
+         total number of tokens post padding, topk is the number of times
+         each token is repeated, and N is the output feature dimension.
+     - sorted_token_ids: A tensor containing the sorted indices of tokens,
+         repeated topk times and arranged by the expert index they are
+         assigned to.
+     - expert_ids: A tensor containing the indices of the expert for each
+         block. It determines which expert matrix from B should be used for
+         each block in A.
+     This kernel performs the multiplication of a token by its corresponding
+     expert matrix as determined by `expert_ids`. The sorting of
+     `sorted_token_ids` by expert index and padding ensures divisibility by
+     BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+     multiplication across different blocks processed by the same expert.
+     """
+     # -----------------------------------------------------------
+     # Map program ids `pid` to the block of C it should compute.
+     # This is done in a grouped ordering to promote L2 data reuse.
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     # ----------------------------------------------------------
+     # Create pointers for the first blocks of A and B.
+     # We will advance this pointer as we move in the K direction
+     # and accumulate
+     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+         return
+     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+     token_mask = offs_token < num_valid_tokens
+
+     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+     if off_experts == -1:
+         # -----------------------------------------------------------
+         # Write back zeros to the output when the expert is not
+         # in the current expert parallel rank.
+         write_zeros_to_output(
+             c_ptr,
+             stride_cm,
+             stride_cn,
+             pid_n,
+             N,
+             offs_token,
+             token_mask,
+             BLOCK_SIZE_M,
+             BLOCK_SIZE_N,
+             compute_type,
+         )
+         return
+
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = a_ptr + (
+         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+     )
+
+     if use_int4_w4a16:
+         b_ptrs = (
+             b_ptr
+             + off_experts * stride_be
+             + (offs_k[:, None] // 2) * stride_bk
+             + offs_bn[None, :] * stride_bn
+         )
+         b_shifter = (offs_k[:, None] % 2) * 4
+     elif use_int8_w8a16:
+         b_ptrs = (
+             b_ptr
+             + off_experts * stride_be
+             + offs_k[:, None] * stride_bk
+             + offs_bn[None, :] * stride_bn
+         )
+
+     if not has_zp and use_int4_w4a16:
+         b_zp_num = 8
+     if not has_zp and use_int8_w8a16:
+         b_zp_num = 128
+     elif has_zp and use_int4_w4a16:
+         b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+     # -----------------------------------------------------------
+     # Iterate to compute a block of the C matrix.
+     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+     # of fp32 values for higher accuracy.
+     # `accumulator` will be converted back to fp16 after the loop.
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         # Load the next block of A and B, generate a mask by checking the
+         # K dimension.
+
+         if not even_Ks:
+             k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+             k_other = 0.0
+         else:
+             k_mask = None
+             k_other = None
+
+         a = tl.load(
+             a_ptrs,
+             mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+             other=0.0,
+         )
+         b = tl.load(b_ptrs)
+         if use_int4_w4a16:
+             b = (b >> b_shifter) & 0xF
+
+         b_scale_ptrs = (
+             b_scale_ptr
+             + off_experts * stride_bse
+             + offs_bn[None, :] * stride_bsn
+             + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+         )
+         b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+         b_scale = b_scale.to(tl.float32)
+
+         if has_zp and use_int4_w4a16:
+             offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+             b_zp_ptrs = (
+                 b_zp_ptr
+                 + off_experts * stride_bze
+                 + (offs_bn[None, :] // 2) * stride_bzn
+                 + offs_k_true * stride_bzk
+             )
+             b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+             b_zp = (b_zp >> b_zp_shifter) & 0xF
+             b_zp = b_zp.to(tl.float32)
+         elif has_zp and use_int8_w8a16:
+             offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+             b_zp_ptrs = (
+                 b_zp_ptr
+                 + off_experts * stride_bze
+                 + offs_bn[None, :] * stride_bzn
+                 + offs_k_true * stride_bzk
+             )
+             b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+             b_zp = b_zp.to(tl.float32)
+
+         # We accumulate along the K dimension.
+         if has_zp:
+             b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+         else:
+             b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+         accumulator = tl.dot(a, b, acc=accumulator)
+
+         # Advance the ptrs to the next K block.
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         if use_int4_w4a16:
+             b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+         else:
+             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if MUL_ROUTED_WEIGHT:
+         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+         accumulator = accumulator * moe_weight[:, None]
+
+     accumulator = accumulator.to(compute_type)
+     # -----------------------------------------------------------
+     # Write back the block of the output
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+ @triton.jit
+ def fused_moe_kernel(
+     # Pointers to matrices
+     a_ptr,
+     b_ptr,
+     bias_ptr,
+     c_ptr,
+     a_scale_ptr,
+     b_scale_ptr,
+     topk_weights_ptr,
+     sorted_token_ids_ptr,
+     expert_ids_ptr,
+     num_tokens_post_padded_ptr,
+     # Matrix dimensions
+     N,
+     K,
+     EM,
+     num_valid_tokens,
+     # The stride variables represent how much to increase the ptr by when
+     # moving by 1 element in a particular dimension. E.g. `stride_am` is
+     # how much to increase `a_ptr` by to get the element one row down
+     # (A has M rows).
+     stride_am,
+     stride_ak,
+     stride_be,
+     stride_bk,
+     stride_bn,
+     stride_bias_e,
+     stride_bias_n,
+     stride_cm,
+     stride_cn,
+     stride_asm,
+     stride_ask,
+     stride_bse,
+     stride_bsk,
+     stride_bsn,
+     # Block size for block-wise quantization
+     group_n: tl.constexpr,
+     group_k: tl.constexpr,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+     MUL_ROUTED_WEIGHT: tl.constexpr,
+     top_k: tl.constexpr,
+     compute_type: tl.constexpr,
+     use_fp8_w8a8: tl.constexpr,
+     use_int8_w8a8: tl.constexpr,
+     use_int8_w8a16: tl.constexpr,
+     per_channel_quant: tl.constexpr,
+     even_Ks: tl.constexpr,
+ ):
+     """
+     Implements the fused computation for a Mixture of Experts (MOE) using
+     token and expert matrices.
+
+     Key Parameters:
+     - A: The input tensor representing tokens with shape (*, K), where '*' can
+         be any shape representing batches and K is the feature dimension of
+         each token.
+     - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+         the number of experts, K is the input feature dimension, and N is
+         the output feature dimension.
+     - C: The output cache tensor with shape (M, topk, N), where M is the
+         total number of tokens post padding, topk is the number of times
+         each token is repeated, and N is the output feature dimension.
+     - sorted_token_ids: A tensor containing the sorted indices of tokens,
+         repeated topk times and arranged by the expert index they are
+         assigned to.
+     - expert_ids: A tensor containing the indices of the expert for each
+         block. It determines which expert matrix from B should be used for
+         each block in A.
+
+     This kernel performs the multiplication of a token by its corresponding
+     expert matrix as determined by `expert_ids`. The sorting of
+     `sorted_token_ids` by expert index and padding ensures divisibility by
+     BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+     multiplication across different blocks processed by the same expert.
+     """
+     # -----------------------------------------------------------
+     # Map program ids `pid` to the block of C it should compute.
+     # This is done in a grouped ordering to promote L2 data reuse.
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     # ----------------------------------------------------------
+     # Create pointers for the first blocks of A and B.
+     # We will advance this pointer as we move in the K direction
+     # and accumulate
+     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+         return
+     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+     offs_token = offs_token.to(tl.int64)
+     token_mask = offs_token < num_valid_tokens
+
+     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+     if off_experts == -1:
+         # -----------------------------------------------------------
+         # Write back zeros to the output when the expert is not
+         # in the current expert parallel rank.
+         write_zeros_to_output(
+             c_ptr,
+             stride_cm,
+             stride_cn,
+             pid_n,
+             N,
+             offs_token,
+             token_mask,
+             BLOCK_SIZE_M,
+             BLOCK_SIZE_N,
+             compute_type,
+         )
+         return
+
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = a_ptr + (
+         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+     )
+
+     b_ptrs = (
+         b_ptr
+         + off_experts * stride_be
+         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+     )
+     if bias_ptr is not None:
+         bias = tl.load(
+             bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+         )
+     if use_int8_w8a16:
+         b_scale_ptrs = (
+             b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+         )
+         b_scale = tl.load(b_scale_ptrs)
+
+     if use_fp8_w8a8 or use_int8_w8a8:
+         # block-wise
+         if group_k > 0 and group_n > 0:
+             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+             offs_bsn = offs_bn // group_n
+             b_scale_ptrs = (
+                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+             )
+         # channel-wise
+         elif per_channel_quant:
+             b_scale_ptrs = (
+                 b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+             )
+             b_scale = tl.load(b_scale_ptrs)
+             # Load per-token scale for activations
+             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+             a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+         # tensor-wise
+         else:
+             a_scale = tl.load(a_scale_ptr)
+             b_scale = tl.load(b_scale_ptr + off_experts)
+
+     # -----------------------------------------------------------
+     # Iterate to compute a block of the C matrix.
+     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+     # of fp32 values for higher accuracy.
+     # `accumulator` will be converted back to fp16 after the loop.
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         # Load the next block of A and B, generate a mask by checking the
+         # K dimension.
+         if even_Ks:
+             a = tl.load(
+                 a_ptrs,
+                 mask=token_mask[:, None],
+                 other=0.0,
+             )
+             b = tl.load(b_ptrs)
+         else:
+             a = tl.load(
+                 a_ptrs,
+                 mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                 other=0.0,
+             )
+             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         # We accumulate along the K dimension.
+         if use_int8_w8a16:
+             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
+         elif use_fp8_w8a8 or use_int8_w8a8:
+             if group_k > 0 and group_n > 0:
+                 k_start = k * BLOCK_SIZE_K
+                 offs_ks = k_start // group_k
+                 a_scale = tl.load(
+                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                 )
+                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+             else:
+                 if use_fp8_w8a8:
+                     accumulator = tl.dot(a, b, acc=accumulator)
+                 else:
+                     accumulator += tl.dot(a, b)
+         else:
+             accumulator += tl.dot(a, b)
+         # Advance the ptrs to the next K block.
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if use_int8_w8a16:
+         accumulator *= b_scale
+     elif use_fp8_w8a8 or use_int8_w8a8:
+         if group_k == 0 or group_n == 0:
+             accumulator *= a_scale * b_scale
+
+     if bias_ptr is not None:
+         accumulator += bias
+
+     if MUL_ROUTED_WEIGHT:
+         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+         accumulator *= moe_weight[:, None]
+
+     accumulator = accumulator.to(compute_type)
+     # -----------------------------------------------------------
+     # Write back the block of the output
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+ def invoke_fused_moe_kernel(
+     A: torch.Tensor,
+     B: torch.Tensor,
+     bias: Optional[torch.Tensor],
+     C: torch.Tensor,
+     A_scale: Optional[torch.Tensor],
+     B_scale: Optional[torch.Tensor],
+     B_zp: Optional[torch.Tensor],
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     sorted_token_ids: torch.Tensor,
+     expert_ids: torch.Tensor,
+     num_tokens_post_padded: torch.Tensor,
+     mul_routed_weight: bool,
+     top_k: int,
+     config: Dict[str, Any],
+     compute_type: tl.dtype,
+     use_fp8_w8a8: bool,
+     use_int8_w8a8: bool,
+     use_int8_w8a16: bool,
+     use_int4_w4a16: bool,
+     per_channel_quant: bool,
+     block_shape: Optional[List[int]] = None,
+     no_combine: bool = False,
+ ) -> None:
+     assert topk_weights.stride(1) == 1
+     assert sorted_token_ids.stride(0) == 1
+
+     padded_size = 0
+     if use_fp8_w8a8:
+         assert B_scale is not None
+         if block_shape is None:
+             # activation tensor-wise fp8 quantization, dynamic or static
+             padded_size = padding_size
+             # activations apply per-token quantization when weights apply per-channel quantization by default
+             A, A_scale = scaled_fp8_quant(
+                 A, A_scale, use_per_token_if_dynamic=per_channel_quant
+             )
+         else:
+             # activation block-wise fp8 quantization
+             assert len(block_shape) == 2
+             block_n, block_k = block_shape[0], block_shape[1]
+             if _is_cuda:
+                 A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
+             else:
+                 A, A_scale = per_token_group_quant_fp8(A, block_k)
+             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+     elif use_int8_w8a8:
+         assert B_scale is not None
+         if block_shape is None:
+             # activation channel-wise int8 quantization
+             assert (
+                 per_channel_quant
+             ), "int8 quantization only supports channel-wise quantization except for block-wise quantization"
+             A, A_scale = per_token_quant_int8(A)
+         else:
+             # activation block-wise int8 quantization
+             assert len(block_shape) == 2
+             block_n, block_k = block_shape[0], block_shape[1]
+             if _is_cuda:
+                 A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+             else:
+                 A, A_scale = per_token_group_quant_int8(A, block_k)
+             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
+     elif use_int8_w8a16 or use_int4_w4a16:
+         assert B_scale is not None
+         assert block_shape is None or block_shape[0] == 0
+     else:
+         assert A_scale is None
+         assert B_scale is None
+
+     grid = lambda META: (
+         triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+     )
+
+     K = B.shape[2] - padded_size
+     if K % config["BLOCK_SIZE_K"] == 0:
+         even_Ks = True
+     else:
+         even_Ks = False
+
+     if (
+         (use_int8_w8a16 or use_int4_w4a16)
+         and block_shape is not None
+         and block_shape[1] > 0
+     ):
+         assert B_scale is not None and B_scale.ndim == 3
+         assert B_zp is None or B_zp.ndim == 3
+         assert bias is None
+         fused_moe_kernel_gptq_awq[grid](
+             A,
+             B,
+             C,
+             B_scale,
+             B_zp,
+             topk_weights,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             B.shape[1],
+             A.shape[1],
+             sorted_token_ids.shape[0],
+             topk_ids.numel(),
+             A.stride(0),
+             A.stride(1),
+             B.stride(0),
+             B.stride(2),
+             B.stride(1),
+             C.stride(1),
+             C.stride(2),
+             B_scale.stride(0),
+             B_scale.stride(2),
+             B_scale.stride(1),
+             B_zp.stride(0) if B_zp is not None else 0,
+             B_zp.stride(2) if B_zp is not None else 0,
+             B_zp.stride(1) if B_zp is not None else 0,
+             group_size=block_shape[1],
+             MUL_ROUTED_WEIGHT=mul_routed_weight,
+             top_k=top_k,
+             compute_type=compute_type,
+             has_zp=B_zp is not None,
+             use_int4_w4a16=use_int4_w4a16,
+             use_int8_w8a16=use_int8_w8a16,
+             even_Ks=even_Ks,
+             **config,
+         )
+
+     else:
+
+         fused_moe_kernel[grid](
+             A,
+             B,
+             bias,
+             C,
+             A_scale,
+             B_scale,
+             topk_weights,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             B.shape[1],
+             B.shape[2] - padded_size,
+             sorted_token_ids.shape[0],
+             topk_ids.numel(),
+             A.stride(0),
+             A.stride(1),
+             B.stride(0),
+             B.stride(2),
+             B.stride(1),
+             bias.stride(0) if bias is not None else 0,
+             bias.stride(1) if bias is not None else 0,
+             C.stride(1),
+             C.stride(2),
+             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+             A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+             B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+             B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+             B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+             0 if block_shape is None else block_shape[0],
+             0 if block_shape is None else block_shape[1],
+             MUL_ROUTED_WEIGHT=mul_routed_weight,
+             top_k=top_k,
+             compute_type=compute_type,
+             use_fp8_w8a8=use_fp8_w8a8,
+             use_int8_w8a8=use_int8_w8a8,
+             use_int8_w8a16=use_int8_w8a16,
+             per_channel_quant=per_channel_quant,
+             even_Ks=even_Ks,
+             **config,
+         )
+
+
+ # _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
+ @triton.jit
+ def _moe_sum_reduce_kernel(
+     input_ptr,
+     input_stride_0,
+     input_stride_1,
+     input_stride_2,
+     output_ptr,
+     output_stride_0,
+     output_stride_1,
+     token_num: int,
+     topk_num: int,
+     hidden_dim: int,
+     routed_scaling_factor: tl.constexpr,
+     BLOCK_M: tl.constexpr,
+     BLOCK_DIM: tl.constexpr,
+     NUM_STAGE: tl.constexpr,
+ ):
+     input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
+     input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
+     output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
+
+     token_block_id = tl.program_id(0)
+     dim_block_id = tl.program_id(1)
+
+     offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
+     offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
+
+     mask_token = offs_token < token_num
+     mask_dim = offs_dim < hidden_dim
+
+     base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]
+
+     accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
+
+     for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
+         tile = tl.load(
+             base_ptrs + i * input_stride_1,
+             mask=mask_token[:, None] & mask_dim[None, :],
+             other=0.0,
+         )
+         accumulator += tile.to(tl.float32)
+     accumulator *= routed_scaling_factor
+
+     # -------- Write back --------
+     store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
+     tl.store(
+         store_ptrs,
+         accumulator.to(input_ptr.dtype.element_ty),
+         mask=mask_token[:, None] & mask_dim[None, :],
+     )
+
+
+ def moe_sum_reduce_triton(
+     input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
+ ):
+     assert input.is_contiguous()
+     assert output.is_contiguous()
+
+     token_num, topk_num, hidden_dim = input.shape
+     assert output.shape[0] == token_num and output.shape[1] == hidden_dim
+
+     BLOCK_M = 1
+     BLOCK_DIM = 2048
+     NUM_STAGE = 1
+     num_warps = 16
+
+     grid = (
+         triton.cdiv(token_num, BLOCK_M),
+         triton.cdiv(hidden_dim, BLOCK_DIM),
+     )
+
+     _moe_sum_reduce_kernel[grid](
+         input,
+         *input.stride(),
+         output,
+         *output.stride(),
+         token_num=token_num,
+         topk_num=topk_num,
+         hidden_dim=hidden_dim,
+         routed_scaling_factor=routed_scaling_factor,
+         BLOCK_M=BLOCK_M,
+         BLOCK_DIM=BLOCK_DIM,
+         NUM_STAGE=NUM_STAGE,
+         num_warps=num_warps,
+     )
+     return
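
Note: the new module ends with the helper moe_sum_reduce_triton, which sums the top-k expert outputs per token and applies the routed scaling factor. The snippet below is a minimal sketch (not part of the diff) of how it could be exercised, assuming a CUDA-capable GPU and that the function remains importable from the module path of the file added above; shapes and dtypes are chosen only for illustration.

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_kernels import (
    moe_sum_reduce_triton,
)

# Per-token expert outputs: (token_num, topk, hidden_dim), reduced to (token_num, hidden_dim).
token_num, topk, hidden_dim = 8, 2, 4096
x = torch.randn(token_num, topk, hidden_dim, dtype=torch.float16, device="cuda")
out = torch.empty(token_num, hidden_dim, dtype=torch.float16, device="cuda")

moe_sum_reduce_triton(x, out, routed_scaling_factor=1.0)

# Plain PyTorch reference reduction for comparison.
ref = (x.float().sum(dim=1) * 1.0).to(x.dtype)
assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2)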