sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,88 @@
1
+ """
2
+ Run one test prompt.
3
+
4
+ Usage:
5
+ python3 -m sglang.test.send_one
6
+ """
7
+
8
+ import argparse
9
+ import json
10
+
11
+ import requests
12
+
13
+
14
def send_one_prompt(args):
    """Send a single /generate request to a local sglang server and report stats.

    Args:
        args: argparse namespace with `prompt`, `temperature`, `max_new_tokens`,
            `frequency_penalty`, `presence_penalty`, `return_logprob`, `stream`,
            and `image` attributes.

    Returns:
        Tuple of (acc_length, speed): average completion tokens per speculative
        verify step (1.0 when speculative decoding is off) and decode speed in
        tokens per second.

    Raises:
        requests.HTTPError: If the server responds with a non-2xx status.
        RuntimeError: If a streaming response produces no data events.
    """
    if args.image:
        # Replace the text prompt with an image-description task.
        args.prompt = (
            "Human: Describe this image in a very short sentence.\n\nAssistant:"
        )
        image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
    else:
        image_data = None

    response = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": args.prompt,
            "image_data": image_data,
            "sampling_params": {
                "temperature": args.temperature,
                "max_new_tokens": args.max_new_tokens,
                "frequency_penalty": args.frequency_penalty,
                "presence_penalty": args.presence_penalty,
            },
            "return_logprob": args.return_logprob,
            "stream": args.stream,
        },
        stream=args.stream,
    )
    # Fail fast on HTTP errors instead of crashing later on a missing key.
    response.raise_for_status()

    if args.stream:
        # Keep the latest SSE event; the final one carries complete meta_info.
        ret = None
        for chunk in response.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                ret = json.loads(chunk[5:].strip("\n"))
        if ret is None:
            # Previously `ret` stayed unbound here and raised a confusing
            # NameError below; surface the real problem explicitly.
            raise RuntimeError("No data received from the streaming response.")
    else:
        ret = response.json()

    latency = ret["meta_info"]["e2e_latency"]

    if "spec_verify_ct" in ret["meta_info"]:
        # Speculative decoding: average tokens accepted per verify step.
        acc_length = (
            ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
        )
    else:
        acc_length = 1.0

    speed = ret["meta_info"]["completion_tokens"] / latency

    print(ret["text"])
    print()
    print(f"{acc_length=:.2f}")
    print(f"{speed=:.2f} token/s")

    return acc_length, speed
67
+
68
+
69
if __name__ == "__main__":
    # Command-line interface: fire a single test prompt at a local server.
    cli = argparse.ArgumentParser()
    cli.add_argument("--temperature", type=float, default=0.0)
    cli.add_argument("--max-new-tokens", type=int, default=512)
    cli.add_argument("--frequency-penalty", type=float, default=0.0)
    cli.add_argument("--presence-penalty", type=float, default=0.0)
    cli.add_argument("--return-logprob", action="store_true")
    cli.add_argument(
        "--prompt",
        type=str,
        default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
    )
    cli.add_argument("--image", action="store_true")
    cli.add_argument("--stream", action="store_true")
    send_one_prompt(cli.parse_args())
@@ -0,0 +1,361 @@
1
+ import itertools
2
+ import random
3
+ import unittest
4
+ from typing import Any, Callable, Dict, List, Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from sglang.srt.layers.moe.ep_moe.kernels import (
9
+ grouped_gemm_triton,
10
+ post_reorder_triton_kernel,
11
+ pre_reorder_triton_kernel,
12
+ run_moe_ep_preproess,
13
+ silu_and_mul_triton_kernel,
14
+ )
15
+ from sglang.srt.layers.moe.topk import select_experts
16
+
17
+
18
# For test
def ep_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    # ep config
    num_experts: int = 256,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
    num_experts_per_partition: int = 128,
    start_expert_id: int = 0,
    end_expert_id: int = 127,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    w1_scale_inv: Optional[torch.Tensor] = None,
    w2_scale_inv: Optional[torch.Tensor] = None,
    block_shape: Optional[List[int]] = None,
) -> torch.Tensor:
    """Reference expert-parallel MoE forward pass built from the EP-MoE kernels.

    Routes each token to `top_k` experts, runs the gate/up grouped GEMM, the
    fused SiLU-and-mul activation, and the down-projection grouped GEMM for
    the experts owned by this partition ([start_expert_id, end_expert_id]),
    then scatters the weighted expert outputs back to token order.

    Args:
        hidden_states: (num_tokens, hidden_size) input activations.
        w1: Gate/up projection weights for all experts.
        w2: Down projection weights for all experts.
        router_logits: (num_tokens, num_experts) routing scores.
        top_k: Number of experts selected per token.
        renormalize: Whether to renormalize the top-k routing weights.
        num_experts: Total number of experts across all partitions.
        fp8_dtype: FP8 storage dtype used for per-tensor-quantized activations.
        num_experts_per_partition: Experts owned by this (simulated) EP rank.
        start_expert_id: First expert id of this rank (inclusive).
        end_expert_id: Last expert id of this rank (inclusive).
        use_fp8_w8a8: Run the grouped GEMMs in FP8 weight/activation mode.
        w1_scale_inv: Dequant scales for `w1` (per-tensor or blockwise).
        w2_scale_inv: Dequant scales for `w2` (per-tensor or blockwise).
        block_shape: [block_n, block_k] for blockwise FP8; None => per-tensor.

    Returns:
        (num_tokens, hidden_size) combined MoE output in `hidden_states.dtype`.
    """
    # Blockwise FP8 keeps activations in the original dtype; only per-tensor
    # FP8 quantizes the intermediate activation buffers to `fp8_dtype`.
    use_blockwise_fp8 = block_shape is not None
    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        # correction_bias=correction_bias, #skip this in test
        custom_routing_function=custom_routing_function,
    )

    # Sort (token, expert) pairs by expert id: `reorder_topk_ids` are the
    # sorted ids, `src2dst` maps each slot to its sorted position, and
    # `seg_indptr` gives per-expert segment boundaries.
    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

    # One row per (token, expert) pair.
    gateup_input = torch.empty(
        (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
        device=hidden_states.device,
        dtype=(
            fp8_dtype
            if (use_fp8_w8a8 and not use_blockwise_fp8)
            else hidden_states.dtype
        ),
    )

    if use_fp8_w8a8 and not use_blockwise_fp8:
        # Per-tensor activation scale, replicated once per local expert.
        max_value = (
            torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
        )
        w1_input_scale = max_value / torch.finfo(fp8_dtype).max
    else:
        w1_input_scale = None

    # PreReorder: gather token activations into expert-sorted order
    # (quantizing to FP8 on the fly when `w1_input_scale` is set).
    pre_reorder_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        gateup_input,
        src2dst,
        topk_ids,
        w1_input_scale,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.shape[1],
        BLOCK_SIZE=512,
    )

    # num_experts_per_partition + 1 segment boundaries for the local experts.
    seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
    weight_indices_cur_rank = torch.arange(
        0,
        num_experts_per_partition,
        device=hidden_states.device,
        dtype=torch.int64,
    )

    # GroupGemm-0: gate/up projection for all local experts' segments.
    gateup_output = torch.empty(
        gateup_input.shape[0],
        w1.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    gateup_output = grouped_gemm_triton(
        a=gateup_input,
        b=w1,
        c=gateup_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w1_input_scale,
        scale_b=w1_scale_inv,
        block_shape=block_shape,
    )

    # Act: fused SiLU(gate) * up halves the feature dimension.
    down_input = torch.empty(
        gateup_output.shape[0],
        gateup_output.shape[1] // 2,
        device=gateup_output.device,
        dtype=(
            fp8_dtype
            if (use_fp8_w8a8 and not use_blockwise_fp8)
            else hidden_states.dtype
        ),
    )
    if use_fp8_w8a8 and not use_blockwise_fp8:
        # Unit input scales for the second GEMM in per-tensor FP8 mode.
        w2_input_scale = torch.ones(
            num_experts_per_partition,
            dtype=torch.float32,
            device=hidden_states.device,
        )
    else:
        w2_input_scale = None

    silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
        gateup_output,
        down_input,
        gateup_output.shape[1],
        reorder_topk_ids,
        w2_input_scale,
        start_expert_id,
        end_expert_id,
        BLOCK_SIZE=512,
    )

    # GroupGemm-1: down projection back to the hidden size.
    down_output = torch.empty(
        down_input.shape[0],
        w2.shape[1],
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    down_output = grouped_gemm_triton(
        a=down_input,
        b=w2,
        c=down_output,
        batch_size=num_experts_per_partition,
        weight_column_major=True,
        seg_indptr=seg_indptr_cur_rank,
        weight_indices=weight_indices_cur_rank,
        use_fp8_w8a8=use_fp8_w8a8,
        scale_a=w2_input_scale,
        scale_b=w2_scale_inv,
        block_shape=block_shape,
    )

    # PostReorder: scatter expert outputs back to token order, weighting each
    # contribution by its routing weight and summing over the top-k experts.
    output = torch.empty_like(hidden_states)
    post_reorder_triton_kernel[(hidden_states.size(0),)](
        down_output,
        output,
        src2dst,
        topk_ids,
        topk_weights,
        start_expert_id,
        end_expert_id,
        top_k,
        hidden_states.size(1),
        BLOCK_SIZE=512,
    )
    return output
186
+
187
+
188
# test util
def block_dequant(
    x_q_block: torch.Tensor,
    x_s: torch.Tensor,
    block_size: List[int],
) -> torch.Tensor:
    """Dequantize a block-wise quantized tensor back to float32.

    Multiplies each (block_n, block_k) tile of `x_q_block` by its per-tile
    scale from `x_s`. 3D inputs are processed batch-by-batch along dim 0.
    Note only float8 inputs are intended for now.

    Args:
        x_q_block: Quantized tensor of shape (n, k), or (batch, n, k).
        x_s: Per-tile scales of shape (n_tiles, k_tiles), with a matching
            leading batch dimension for 3D input.
        block_size: [block_n, block_k] tile shape.

    Returns:
        The dequantized tensor in float32 with the same shape as `x_q_block`.
        (The original annotation promised a (tensor, scale) tuple, but only
        the dequantized tensor is ever returned.)
    """
    # Process a 3D tensor by recursing over the batch dimension.
    if x_q_block.dim() == 3:
        batch_size = x_q_block.size(0)
        return torch.stack(
            [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
        )

    block_n, block_k = block_size[0], block_size[1]
    n, k = x_q_block.shape
    n_tiles = (n + block_n - 1) // block_n
    k_tiles = (k + block_k - 1) // block_k
    assert n_tiles == x_s.shape[0]
    assert k_tiles == x_s.shape[1]

    x_dq_block = x_q_block.to(torch.float32)

    # Scale every tile in place; edge tiles may be smaller than the block.
    for j in range(n_tiles):
        for i in range(k_tiles):
            x_dq_block[
                j * block_n : min((j + 1) * block_n, n),
                i * block_k : min((i + 1) * block_k, k),
            ] *= x_s[j][i]

    return x_dq_block
233
+
234
+
235
class TestW8A8BlockFP8EPMoE(unittest.TestCase):
    """Compare blockwise-FP8 EP-MoE against a dequantized reference run.

    Sweeps token count (M), intermediate size (N), hidden size (K), expert
    count (E), EP partitioning, top-k, block size, dtype, and seed; for each
    combination it runs `ep_moe` once with FP8 weights + blockwise scales and
    once with dequantized full-precision weights, and checks the relative
    mean absolute error stays below 6%.
    """

    # Parameter grids for the sweep in test_w8a8_block_fp8_ep_moe.
    DTYPES = [torch.half, torch.bfloat16]
    M = [1, 222, 1024, 2048]
    N = [128, 1024, 2048]
    K = [256, 4096, 5120]
    E = [8, 16]
    ep_size = [2, 4]
    TOP_KS = [2, 4]
    BLOCK_SIZE = [[128, 128]]
    SEEDS = [0]

    @classmethod
    def setUpClass(cls):
        # The Triton EP-MoE kernels require a CUDA device.
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA is not available")
        torch.set_default_device("cuda")

    def _w8a8_block_fp8_ep_moe(
        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
    ):
        """Run one FP8-vs-reference comparison for a single parameter tuple."""
        torch.manual_seed(seed)
        random.seed(seed)
        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
        factor_for_scale = 1e-2
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a = torch.randn((M, K), dtype=dtype) / 10

        # Random weights spanning the FP8 range, then clamped and cast.
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
        w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
        w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        # Per-(128,128)-tile dequant scales for both weight matrices.
        block_n, block_k = block_size[0], block_size[1]
        n_tiles_w1 = (2 * N + block_n - 1) // block_n
        n_tiles_w2 = (K + block_n - 1) // block_n
        k_tiles_w1 = (K + block_k - 1) // block_k
        k_tiles_w2 = (N + block_k - 1) // block_k

        w1_s = (
            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
            * factor_for_scale
        )
        w2_s = (
            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
            * factor_for_scale
        )

        # Dequantized weights for the full-precision reference path.
        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)

        score = torch.randn((M, E), dtype=dtype)
        # Simulate one EP rank by picking a random contiguous expert range.
        num_experts_per_partition = E // ep_size
        cur_rank = random.randint(0, ep_size - 1)
        start_id = cur_rank * num_experts_per_partition
        end_id = start_id + num_experts_per_partition - 1

        with torch.inference_mode():
            out = ep_moe(
                hidden_states=a,
                w1=w1,
                w2=w2,
                router_logits=score,
                top_k=topk,
                renormalize=False,
                use_fp8_w8a8=True,
                w1_scale_inv=w1_s,
                w2_scale_inv=w2_s,
                block_shape=block_size,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )
            ref_out = ep_moe(
                hidden_states=a,
                w1=w1_ref,
                w2=w2_ref,
                router_logits=score,
                top_k=topk,
                renormalize=False,
                use_fp8_w8a8=False,
                w1_scale_inv=None,
                w2_scale_inv=None,
                block_shape=None,
                num_experts=E,
                num_experts_per_partition=num_experts_per_partition,
                start_expert_id=start_id,
                end_expert_id=end_id,
            )
        # Relative mean absolute error below 6% (epsilon guards div-by-zero).
        self.assertTrue(
            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
            / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
            < 0.06
        )

    def test_w8a8_block_fp8_ep_moe(self):
        """Sweep the full parameter grid, one subTest per combination."""
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.ep_size,
            self.TOP_KS,
            self.BLOCK_SIZE,
            self.DTYPES,
            self.SEEDS,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                ep_size=params[4],
                topk=params[5],
                block_size=params[6],
                dtype=params[7],
                seed=params[8],
            ):
                self._w8a8_block_fp8_ep_moe(*params)
                # Free GPU memory between combinations to avoid OOM.
                torch.cuda.empty_cache()
358
+
359
+
360
if __name__ == "__main__":
    # Run the full parameter sweep with per-test verbose output.
    unittest.main(verbosity=2)
@@ -536,7 +536,7 @@ def test_hellaswag_select():
536
536
  # Compute accuracy
537
537
  accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
538
538
  print(f"{accuracy=}, {accuracy_gen=}")
539
- assert np.abs(accuracy_gen - accuracy) < 0.05
539
+ assert np.abs(accuracy_gen - accuracy) < 0.1
540
540
  assert np.abs(latency_gen - latency) < 1
541
541
 
542
542
  return accuracy, latency