sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +77 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/detokenizer_manager.py +2 -0
- sglang/srt/managers/io_struct.py +40 -9
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/scheduler.py +69 -21
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +48 -10
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -0
- sglang/srt/models/llama.py +11 -0
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +60 -2
- sglang/srt/openai_api/protocol.py +48 -0
- sglang/srt/server.py +26 -3
- sglang/srt/server_args.py +24 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/METADATA +3 -3
- sglang-0.4.1.post3.dist-info/RECORD +305 -0
- sglang-0.4.1.post1.dist-info/RECORD +0 -195
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/top_level.txt +0 -0
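Most of the new JSON files listed above are Triton kernel tuning tables: each maps an approximate batch size M to block/warp/stage parameters for the fused MoE or block-FP8 GEMM kernels, and the file to load is chosen by tensor shape, GPU name, dtype, and block shape. The sketch below is illustrative only and not part of the package; it assumes the filename convention and nearest-batch-size lookup used by get_w8a8_block_fp8_configs further down in this diff, and the example path is hypothetical.

import json
import os

def pick_kernel_config(config_dir: str, N: int, K: int, device_name: str, M: int) -> dict:
    # Filename convention mirrors the keys visible in the file list above
    # (shape, device name, dtype, block shape); the exact name here is illustrative.
    name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[128, 128].json"
    with open(os.path.join(config_dir, name)) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Entries are keyed by batch size; the key closest to the actual M is used.
    return configs[min(configs, key=lambda k: abs(k - M))]

# e.g. pick_kernel_config("sglang/srt/layers/quantization/configs",
#                         N=576, K=7168, device_name="NVIDIA_H200", M=48)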
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
sglang/srt/layers/quantization/fp8.py CHANGED

@@ -28,7 +28,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -273,6 +272,19 @@ class Fp8LinearMethod(LinearMethodBase):
     def process_weights_after_loading(self, layer: Module) -> None:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.weight,
+                    weight_scale=layer.weight_scale_inv,
+                    input_scale=None,
+                )
+                layer.weight = torch.nn.Parameter(weight, require_grad=False)
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    weight_scale, require_grad=False
+                )
+                layer.input_scale = None
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
@@ -370,7 +382,7 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
                 weight_scale=layer.weight_scale_inv,
-                input_scale=
+                input_scale=None,
                 bias=bias,
             )
 
@@ -548,8 +560,36 @@ class Fp8MoEMethod:
         layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
             return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
sglang/srt/layers/quantization/fp8_kernel.py CHANGED

@@ -12,12 +12,23 @@
 # limitations under the License.
 # ==============================================================================
 
-
+import functools
+import json
+import logging
+import os
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import get_device_name, is_hip
+
+is_hip_ = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+
+logger = logging.getLogger(__name__)
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
@@ -65,7 +76,7 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype =
+    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -87,9 +98,13 @@ def per_token_group_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
 
     finfo = torch.finfo(dtype)
-    fp8_min = finfo.min
     fp8_max = finfo.max
 
+    if is_hip_:
+        fp8_max = 224.0
+
+    fp8_min = -fp8_max
+
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
@@ -205,6 +220,48 @@ def _w8a8_block_fp8_matmul(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
+@functools.lru_cache
+def get_w8a8_block_fp8_configs(
+    N: int, K: int, block_n: int, block_k: int
+) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the w8a8 block fp8 kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    device_name = get_device_name().replace(" ", "_")
+    json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json"
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                "Using configuration from %s for W8A8 Block FP8 kernel.",
+                config_file_path,
+            )
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    logger.warning(
+        (
+            "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! "
+            "Config file not found at %s"
+        ),
+        config_file_path,
+    )
+    return None
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -245,17 +302,22 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)
 
-
-
-
-
-
-
-
-
-
-
-
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+    else:
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
 
     def grid(META):
         return (
@@ -283,10 +345,7 @@ def w8a8_block_fp8_matmul(
         As.stride(-1),
         Bs.stride(1),
         Bs.stride(0),
-
-        BLOCK_SIZE_N=BLOCK_SIZE_N,
-        BLOCK_SIZE_K=BLOCK_SIZE_K,
-        GROUP_SIZE_M=8,
+        **config,
     )
 
     return C
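The docstring of get_w8a8_block_fp8_configs above describes the selection rule: within a tuning table, the entry whose batch-size key is closest to the actual M is used, and a fixed default is used when no table exists for the device and shape. A small worked example of that rule (the table values below are illustrative, not taken from the package):

# Assume a loaded table with batch-size keys 1, 16, 64, 256 (illustrative).
configs = {
    1:   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "num_stages": 3},
    16:  {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32,  "num_stages": 4},
    64:  {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32,  "num_stages": 4},
    256: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "num_stages": 3},
}
M = 48
# Same expression as in w8a8_block_fp8_matmul: the nearest key wins.
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# |48 - 64| = 16 is the smallest distance, so the "64" entry is selected.
assert config is configs[64]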
sglang/srt/layers/quantization/fp8_utils.py CHANGED

@@ -7,6 +7,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
 )
+from sglang.srt.utils import is_hip
+
+is_hip_ = is_hip()
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -63,8 +66,11 @@ def input_to_float8(
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-
-
+    fp8_max = finfo.max
+    if is_hip_:
+        fp8_max = 224.0
+    scale = fp8_max / amax
+    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
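For reference, the input_to_float8 change above amounts to capping the representable FP8 maximum at 224.0 on ROCm (where float8_e4m3fnuz is used) before computing the per-tensor scale. A minimal standalone sketch of that scaling, following the same formula as the diff; the function name is hypothetical and amax handling is simplified:

import torch

def to_float8_sketch(x: torch.Tensor, is_hip: bool = False):
    # Mirrors input_to_float8 after this change: pick the fp8 max, cap it on ROCm,
    # scale so the largest |value| maps near fp8_max, then clamp and cast.
    dtype = torch.float8_e4m3fnuz if is_hip else torch.float8_e4m3fn
    fp8_max = torch.finfo(dtype).max
    if is_hip:
        fp8_max = 224.0
    amax = x.abs().max().clamp(min=1e-12)
    scale = fp8_max / amax
    x_q = (x * scale).clamp(min=-fp8_max, max=fp8_max).to(dtype)
    return x_q, scale.float().reciprocal()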
sglang/srt/managers/detokenizer_manager.py CHANGED

@@ -181,6 +181,8 @@ class DetokenizerManager:
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
                 prompt_tokens=recv_obj.prompt_tokens,
+                origin_input_ids=recv_obj.origin_input_ids,
+                output_ids=recv_obj.output_ids,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
sglang/srt/managers/io_struct.py CHANGED

@@ -21,10 +21,20 @@ from dataclasses import dataclass
 from enum import Enum
 from typing import Dict, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
 
 
+@dataclass
+class SessionParams:
+    id: Optional[str] = None
+    rid: Optional[str] = None
+    offset: Optional[int] = None
+    replace: Optional[bool] = None
+
+
 @dataclass
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
@@ -56,10 +66,8 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Session
-
-        Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
-    ] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
 
     def normalize_batch_and_arguments(self):
         if (
@@ -221,9 +229,8 @@ class TokenizedGenerateReqInput:
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
-    # Session
-
-    session_rid: Optional[str] = None
+    # Session info for continual prompting
+    session_params: Optional[SessionParams] = None
 
 
 @dataclass
@@ -316,7 +323,9 @@ class BatchTokenIDOut:
     decoded_texts: List[str]
     decode_ids: List[int]
     read_offsets: List[int]
-    # Only used when
+    # Only used when `--return-token-ids` is set
+    origin_input_ids: Optional[List[int]]
+    # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set
     output_ids: Optional[List[int]]
     # Detokenization configs
     skip_special_tokens: List[bool]
@@ -347,10 +356,18 @@ class BatchStrOut:
     # The output decoded strings
     output_strs: List[str]
 
+    # The token ids
+    origin_input_ids: Optional[List[int]]
+    output_ids: Optional[List[int]]
+
     # Token counts
+    # real input and output tokens can be get from
+    # origin_input_ids and output_ids by enabling --return_token_ids
+    # TODO (Shuai): Rename this to clarify the meaning.
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]
+
     # Logprobs
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -407,6 +424,18 @@ class UpdateWeightsFromDistributedReqOutput:
     message: str
 
 
+@dataclass
+class UpdateWeightsFromTensorReqInput:
+    name: str
+    tensor: torch.Tensor
+
+
+@dataclass
+class UpdateWeightsFromTensorReqOutput:
+    success: bool
+    message: str
+
+
 @dataclass
 class InitWeightsUpdateGroupReqInput:
     # The master address
@@ -454,6 +483,7 @@ class ProfileReq(Enum):
 @dataclass
 class OpenSessionReqInput:
     capacity_of_str_len: int
+    session_id: Optional[str] = None
 
 
 @dataclass
@@ -463,4 +493,5 @@ class CloseSessionReqInput:
 
 @dataclass
 class OpenSessionReqOutput:
-    session_id: str
+    session_id: Optional[str]
+    success: bool