sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/test/send_one.py
ADDED
@@ -0,0 +1,88 @@
+"""
+Run one test prompt.
+
+Usage:
+python3 -m sglang.test.send_one
+"""
+
+import argparse
+import json
+
+import requests
+
+
+def send_one_prompt(args):
+    if args.image:
+        args.prompt = (
+            "Human: Describe this image in a very short sentence.\n\nAssistant:"
+        )
+        image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
+    else:
+        image_data = None
+
+    response = requests.post(
+        "http://localhost:30000/generate",
+        json={
+            "text": args.prompt,
+            "image_data": image_data,
+            "sampling_params": {
+                "temperature": args.temperature,
+                "max_new_tokens": args.max_new_tokens,
+                "frequency_penalty": args.frequency_penalty,
+                "presence_penalty": args.presence_penalty,
+            },
+            "return_logprob": args.return_logprob,
+            "stream": args.stream,
+        },
+        stream=args.stream,
+    )
+
+    if args.stream:
+        for chunk in response.iter_lines(decode_unicode=False):
+            chunk = chunk.decode("utf-8")
+            if chunk and chunk.startswith("data:"):
+                if chunk == "data: [DONE]":
+                    break
+                ret = json.loads(chunk[5:].strip("\n"))
+    else:
+        ret = response.json()
+
+    latency = ret["meta_info"]["e2e_latency"]
+
+    if "spec_verify_ct" in ret["meta_info"]:
+        acc_length = (
+            ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
+        )
+    else:
+        acc_length = 1.0
+
+    speed = ret["meta_info"]["completion_tokens"] / latency
+
+    print(ret["text"])
+    print()
+    print(f"{acc_length=:.2f}")
+    print(f"{speed=:.2f} token/s")
+
+    return acc_length, speed
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--temperature", type=float, default=0.0)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
+    parser.add_argument("--frequency-penalty", type=float, default=0.0)
+    parser.add_argument("--presence-penalty", type=float, default=0.0)
+    parser.add_argument("--return-logprob", action="store_true")
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
+    )
+    parser.add_argument(
+        "--image",
+        action="store_true",
+    )
+    parser.add_argument("--stream", action="store_true")
+    args = parser.parse_args()
+
+    send_one_prompt(args)
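For reference, here is a minimal sketch of driving the new helper programmatically rather than via `python3 -m sglang.test.send_one`. It assumes an sglang server is already listening on http://localhost:30000; the prompt is made up for illustration, and the `argparse.Namespace` fields simply mirror the CLI defaults defined above:

import argparse

from sglang.test.send_one import send_one_prompt

# Hypothetical prompt; any text works. The remaining fields mirror the
# argparse defaults in send_one.py, which send_one_prompt reads directly.
args = argparse.Namespace(
    prompt="Human: What is the capital of France?\n\nAssistant:",
    image=False,
    temperature=0.0,
    max_new_tokens=64,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    return_logprob=False,
    stream=False,
)

# Prints the completion plus acc_length and speed, then returns both.
acc_length, speed = send_one_prompt(args)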
sglang/test/test_block_fp8_ep.py
ADDED
@@ -0,0 +1,361 @@
+import itertools
+import random
+import unittest
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    grouped_gemm_triton,
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel,
+    run_moe_ep_preproess,
+    silu_and_mul_triton_kernel,
+)
+from sglang.srt.layers.moe.topk import select_experts
+
+
+# For test
+def ep_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    renormalize: bool,
+    # ep config
+    num_experts: int = 256,
+    fp8_dtype: torch.types = torch.float8_e4m3fn,
+    num_experts_per_partition: int = 128,
+    start_expert_id: int = 0,
+    end_expert_id: int = 127,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    use_fp8_w8a8: bool = False,
+    w1_scale_inv: Optional[torch.Tensor] = None,
+    w2_scale_inv: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
+):
+    use_blockwise_fp8 = block_shape is not None
+    topk_weights, topk_ids = select_experts(
+        hidden_states=hidden_states,
+        router_logits=router_logits,
+        top_k=top_k,
+        use_grouped_topk=use_grouped_topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        # correction_bias=correction_bias, #skip this in test
+        custom_routing_function=custom_routing_function,
+    )
+
+    reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
+
+    gateup_input = torch.empty(
+        (int(hidden_states.shape[0] * top_k), hidden_states.shape[1]),
+        device=hidden_states.device,
+        dtype=(
+            fp8_dtype
+            if (use_fp8_w8a8 and not use_blockwise_fp8)
+            else hidden_states.dtype
+        ),
+    )
+
+    if use_fp8_w8a8 and not use_blockwise_fp8:
+        max_value = (
+            torch.max(hidden_states).repeat(num_experts_per_partition).to(torch.float32)
+        )
+        w1_input_scale = max_value / torch.finfo(fp8_dtype).max
+    else:
+        w1_input_scale = None
+
+    # PreReorder
+    pre_reorder_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        gateup_input,
+        src2dst,
+        topk_ids,
+        w1_input_scale,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        hidden_states.shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    seg_indptr_cur_rank = seg_indptr[start_expert_id : end_expert_id + 2]
+    weight_indices_cur_rank = torch.arange(
+        0,
+        num_experts_per_partition,
+        device=hidden_states.device,
+        dtype=torch.int64,
+    )
+
+    # GroupGemm-0
+    gateup_output = torch.empty(
+        gateup_input.shape[0],
+        w1.shape[1],
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    gateup_output = grouped_gemm_triton(
+        a=gateup_input,
+        b=w1,
+        c=gateup_output,
+        batch_size=num_experts_per_partition,
+        weight_column_major=True,
+        seg_indptr=seg_indptr_cur_rank,
+        weight_indices=weight_indices_cur_rank,
+        use_fp8_w8a8=use_fp8_w8a8,
+        scale_a=w1_input_scale,
+        scale_b=w1_scale_inv,
+        block_shape=block_shape,
+    )
+
+    # Act
+    down_input = torch.empty(
+        gateup_output.shape[0],
+        gateup_output.shape[1] // 2,
+        device=gateup_output.device,
+        dtype=(
+            fp8_dtype
+            if (use_fp8_w8a8 and not use_blockwise_fp8)
+            else hidden_states.dtype
+        ),
+    )
+    if use_fp8_w8a8 and not use_blockwise_fp8:
+        w2_input_scale = torch.ones(
+            num_experts_per_partition,
+            dtype=torch.float32,
+            device=hidden_states.device,
+        )
+    else:
+        w2_input_scale = None
+
+    silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+        gateup_output,
+        down_input,
+        gateup_output.shape[1],
+        reorder_topk_ids,
+        w2_input_scale,
+        start_expert_id,
+        end_expert_id,
+        BLOCK_SIZE=512,
+    )
+
+    # GroupGemm-1
+    down_output = torch.empty(
+        down_input.shape[0],
+        w2.shape[1],
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    down_output = grouped_gemm_triton(
+        a=down_input,
+        b=w2,
+        c=down_output,
+        batch_size=num_experts_per_partition,
+        weight_column_major=True,
+        seg_indptr=seg_indptr_cur_rank,
+        weight_indices=weight_indices_cur_rank,
+        use_fp8_w8a8=use_fp8_w8a8,
+        scale_a=w2_input_scale,
+        scale_b=w2_scale_inv,
+        block_shape=block_shape,
+    )
+
+    # PostReorder
+    output = torch.empty_like(hidden_states)
+    post_reorder_triton_kernel[(hidden_states.size(0),)](
+        down_output,
+        output,
+        src2dst,
+        topk_ids,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        hidden_states.size(1),
+        BLOCK_SIZE=512,
+    )
+    return output
+
+
+# test util
+def block_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """This function converts block-wise quantization to tensor-wise quantization.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
+    Note only float8 is supported for now.
+    """
+
+    # process 3D tensor
+    if x_q_block.dim() == 3:
+        batch_size = x_q_block.size(0)
+        return torch.stack(
+            [block_dequant(x_q_block[b], x_s[b], block_size) for b in range(batch_size)]
+        )
+
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = x_q_block.to(torch.float32)
+
+    x_dq_block_tiles = [
+        [
+            x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            for i in range(k_tiles)
+        ]
+        for j in range(n_tiles)
+    ]
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+    return x_dq_block
+
+
+class TestW8A8BlockFP8EPMoE(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    M = [1, 222, 1024, 2048]
+    N = [128, 1024, 2048]
+    K = [256, 4096, 5120]
+    E = [8, 16]
+    ep_size = [2, 4]
+    TOP_KS = [2, 4]
+    BLOCK_SIZE = [[128, 128]]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _w8a8_block_fp8_ep_moe(
+        self, M, N, K, E, ep_size, topk, block_size, dtype, seed
+    ):
+        torch.manual_seed(seed)
+        random.seed(seed)
+        # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
+        factor_for_scale = 1e-2
+        fp8_info = torch.finfo(torch.float8_e4m3fn)
+        fp8_max, fp8_min = fp8_info.max, fp8_info.min
+
+        a = torch.randn((M, K), dtype=dtype) / 10
+
+        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=dtype) - 0.5) * 2 * fp8_max
+        w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        w2_fp32 = (torch.rand((E, K, N), dtype=dtype) - 0.5) * 2 * fp8_max
+        w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+        block_n, block_k = block_size[0], block_size[1]
+        n_tiles_w1 = (2 * N + block_n - 1) // block_n
+        n_tiles_w2 = (K + block_n - 1) // block_n
+        k_tiles_w1 = (K + block_k - 1) // block_k
+        k_tiles_w2 = (N + block_k - 1) // block_k
+
+        w1_s = (
+            torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
+            * factor_for_scale
+        )
+        w2_s = (
+            torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
+            * factor_for_scale
+        )
+
+        w1_ref = block_dequant(w1, w1_s, block_size).to(dtype)
+        w2_ref = block_dequant(w2, w2_s, block_size).to(dtype)
+
+        score = torch.randn((M, E), dtype=dtype)
+        num_experts_per_partition = E // ep_size
+        cur_rank = random.randint(0, ep_size - 1)
+        start_id = cur_rank * num_experts_per_partition
+        end_id = start_id + num_experts_per_partition - 1
+
+        with torch.inference_mode():
+            out = ep_moe(
+                hidden_states=a,
+                w1=w1,
+                w2=w2,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+                use_fp8_w8a8=True,
+                w1_scale_inv=w1_s,
+                w2_scale_inv=w2_s,
+                block_shape=block_size,
+                num_experts=E,
+                num_experts_per_partition=num_experts_per_partition,
+                start_expert_id=start_id,
+                end_expert_id=end_id,
+            )
+            ref_out = ep_moe(
+                hidden_states=a,
+                w1=w1_ref,
+                w2=w2_ref,
+                router_logits=score,
+                top_k=topk,
+                renormalize=False,
+                use_fp8_w8a8=False,
+                w1_scale_inv=None,
+                w2_scale_inv=None,
+                block_shape=None,
+                num_experts=E,
+                num_experts_per_partition=num_experts_per_partition,
+                start_expert_id=start_id,
+                end_expert_id=end_id,
+            )
+        self.assertTrue(
+            torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))
+            / (torch.mean(torch.abs(ref_out.to(torch.float32))) + 1e-6)
+            < 0.06
+        )
+
+    def test_w8a8_block_fp8_ep_moe(self):
+        for params in itertools.product(
+            self.M,
+            self.N,
+            self.K,
+            self.E,
+            self.ep_size,
+            self.TOP_KS,
+            self.BLOCK_SIZE,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                M=params[0],
+                N=params[1],
+                K=params[2],
+                E=params[3],
+                ep_size=params[4],
+                topk=params[5],
+                block_size=params[6],
+                dtype=params[7],
+                seed=params[8],
+            ):
+                self._w8a8_block_fp8_ep_moe(*params)
+                torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
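To make the tile arithmetic in `block_dequant` concrete, here is a small self-contained sketch (assuming a PyTorch build with float8 support; the values are made up for illustration). A 4x4 quantized tensor is split into four 2x2 tiles and each tile is multiplied by its own scale, which is exactly what the nested loop in `block_dequant` does:

import torch

n, k, block_n, block_k = 4, 4, 2, 2
# Quantized values, all 1.0, stored in float8; one scale per 2x2 tile.
x_q = torch.ones(n, k).to(torch.float8_e4m3fn)
x_s = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

x_dq = x_q.to(torch.float32)
for j in range(n // block_n):      # tile row
    for i in range(k // block_k):  # tile column
        x_dq[
            j * block_n : (j + 1) * block_n,
            i * block_k : (i + 1) * block_k,
        ] *= x_s[j, i]

# Top-left tile -> 1.0, top-right -> 2.0, bottom-left -> 3.0, bottom-right -> 4.0.
print(x_dq)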
sglang/test/test_programs.py
CHANGED
@@ -536,7 +536,7 @@ def test_hellaswag_select():
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
     print(f"{accuracy=}, {accuracy_gen=}")
-    assert np.abs(accuracy_gen - accuracy) < 0.
+    assert np.abs(accuracy_gen - accuracy) < 0.1
     assert np.abs(latency_gen - latency) < 1
 
     return accuracy, latency