sglang 0.4.1.post1__py3-none-any.whl → 0.4.1.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +1 -0
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/layers/attention/__init__.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +54 -41
- sglang/srt/layers/logits_processor.py +30 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -26
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +42 -2
- sglang/srt/layers/quantization/fp8_kernel.py +77 -18
- sglang/srt/layers/quantization/fp8_utils.py +8 -2
- sglang/srt/managers/detokenizer_manager.py +2 -0
- sglang/srt/managers/io_struct.py +40 -9
- sglang/srt/managers/schedule_batch.py +22 -15
- sglang/srt/managers/scheduler.py +69 -21
- sglang/srt/managers/session_controller.py +102 -27
- sglang/srt/managers/tokenizer_manager.py +48 -10
- sglang/srt/managers/tp_worker.py +7 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -0
- sglang/srt/model_executor/forward_batch_info.py +42 -3
- sglang/srt/model_executor/model_runner.py +4 -0
- sglang/srt/models/llama.py +11 -0
- sglang/srt/models/llama_eagle.py +132 -0
- sglang/srt/openai_api/adapter.py +60 -2
- sglang/srt/openai_api/protocol.py +48 -0
- sglang/srt/server.py +26 -3
- sglang/srt/server_args.py +24 -30
- sglang/srt/speculative/spec_info.py +19 -0
- sglang/srt/utils.py +62 -0
- sglang/version.py +1 -1
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/METADATA +3 -3
- sglang-0.4.1.post3.dist-info/RECORD +305 -0
- sglang-0.4.1.post1.dist-info/RECORD +0 -195
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post1.dist-info → sglang-0.4.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py

@@ -15,7 +15,7 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional, Union
+from typing import List, Optional, Set, Union

 import torch
 from transformers import PretrainedConfig
@@ -47,6 +47,7 @@ class ModelConfig:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
@@ -130,7 +131,8 @@ class ModelConfig:
         # Veirfy quantization
         self._verify_quantization()

-        #
+        # Cache attributes
+        self.hf_eos_token_id = self.get_hf_eos_token_id()
         self.image_token_id = getattr(self.hf_config, "image_token_id", None)

     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
@@ -271,6 +273,13 @@ class ModelConfig:
             self.quantization,
         )

+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+        eos_ids = getattr(self.hf_config, "eos_token_id", None)
+        if eos_ids:
+            # it can be either int or list of int
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        return eos_ids
+

 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.

sglang/srt/layers/attention/flashinfer_backend.py

@@ -8,8 +8,9 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 """

 import os
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Union

 import torch
 import triton
@@ -38,12 +39,25 @@ class WrapperDispatch(Enum):
     CROSS_ATTENTION = auto()


+@dataclass
+class DecodeMetadata:
+    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+
+
+@dataclass
+class PrefillMetadata:
+    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
+    use_ragged: bool
+    extend_no_prefix: bool
+
+
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""

     def __init__(self, model_runner: ModelRunner):
         super().__init__()

+        # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
@@ -52,7 +66,6 @@ class FlashInferAttnBackend(AttentionBackend):
                 model_runner.tp_size
             ),
         )
-
         self.max_context_len = model_runner.model_config.context_len

         assert not (
@@ -120,8 +133,8 @@ class FlashInferAttnBackend(AttentionBackend):
         )

         # Other metadata
-        self.forward_metadata = None
-        self.
+        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
+        self.decode_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode():
@@ -129,10 +142,10 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
-                decode_wrappers=
+                decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
             )
-            self.forward_metadata = (self.decode_wrappers
+            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
         else:
             prefix_lens = forward_batch.extend_prefix_lens

@@ -149,11 +162,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 forward_batch.seq_lens,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
+                prefill_wrappers=self.prefill_wrappers_paged,
                 use_ragged=use_ragged,
                 encoder_lens=forward_batch.encoder_lens,
             )
-
-
+            self.forward_metadata = PrefillMetadata(
+                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+            )

     def init_cuda_graph_state(self, max_bs: int):
         cuda_graph_kv_indices = torch.zeros(
@@ -194,8 +209,8 @@ class FlashInferAttnBackend(AttentionBackend):
             decode_wrappers=decode_wrappers,
             encoder_lens=encoder_lens,
         )
-        self.
-        self.forward_metadata = (decode_wrappers
+        self.decode_cuda_graph_metadata[bs] = decode_wrappers
+        self.forward_metadata = DecodeMetadata(decode_wrappers)

     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -209,7 +224,7 @@ class FlashInferAttnBackend(AttentionBackend):
             req_pool_indices[:bs],
             seq_lens[:bs],
             seq_lens_sum,
-            decode_wrappers=self.
+            decode_wrappers=self.decode_cuda_graph_metadata[bs],
             encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
         )

@@ -225,18 +240,16 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.
+        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
             self._get_wrapper_idx(layer)
         ]
-
-        use_ragged, extend_no_prefix = self.forward_metadata
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
             else forward_batch.encoder_out_cache_loc
         )

-        if not use_ragged:
+        if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
@@ -260,7 +273,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 logits_soft_cap=layer.logit_cap,
             )

-            if extend_no_prefix:
+            if self.forward_metadata.extend_no_prefix:
                 o = o1
             else:
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
@@ -287,7 +300,9 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        decode_wrapper = self.forward_metadata[
+        decode_wrapper = self.forward_metadata.decode_wrappers[
+            self._get_wrapper_idx(layer)
+        ]
         cache_loc = (
             forward_batch.out_cache_loc
             if not layer.is_cross_attention
@@ -322,7 +337,7 @@ class FlashInferAttnBackend(AttentionBackend):

 class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -340,9 +355,8 @@ class FlashInferIndicesUpdaterDecode:
         self.kv_indptr = attn_backend.kv_indptr
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.decode_wrappers = attn_backend.decode_wrappers

-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -356,7 +370,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
@@ -367,7 +381,7 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
         decode_wrappers = decode_wrappers or self.decode_wrappers
@@ -385,11 +399,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
@@ -419,11 +431,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List,
+        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: torch.Tensor,
     ):
-        decode_wrappers = decode_wrappers or self.decode_wrappers
-
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Normal attention
@@ -446,7 +456,7 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper,
+        wrapper: BatchDecodeWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -486,7 +496,7 @@ class FlashInferIndicesUpdaterDecode:

 class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
-        # Constants
+        # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // model_runner.tp_size
         )
@@ -505,10 +515,9 @@ class FlashInferIndicesUpdaterPrefill:
         self.kv_last_page_len = attn_backend.kv_last_page_len
         self.qo_indptr = attn_backend.qo_indptr
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.
-        self.wrappers_paged = attn_backend.prefill_wrappers_paged
+        self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged

-        # Dispatch
+        # Dispatch the update function
         if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
             self.update = self.update_sliding_window
         elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION:
@@ -523,6 +532,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -535,6 +545,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -546,8 +557,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = seq_lens_sum

         self.call_begin_forward(
-            self.
-
+            self.prefill_wrapper_ragged,
+            prefill_wrappers[0],
             req_pool_indices,
             paged_kernel_lens,
             paged_kernel_lens_sum,
@@ -565,6 +576,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -584,8 +596,8 @@ class FlashInferIndicesUpdaterPrefill:
             kv_start_idx = seq_lens - paged_kernel_lens

             self.call_begin_forward(
-                self.
-
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -603,6 +615,7 @@ class FlashInferIndicesUpdaterPrefill:
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
+        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
         use_ragged: bool,
         encoder_lens: torch.Tensor,
     ):
@@ -619,8 +632,8 @@ class FlashInferIndicesUpdaterPrefill:
             paged_kernel_lens_sum = paged_kernel_lens.sum().item()

             self.call_begin_forward(
-                self.
-
+                self.prefill_wrapper_ragged,
+                prefill_wrappers[wrapper_id],
                 req_pool_indices,
                 paged_kernel_lens,
                 paged_kernel_lens_sum,
@@ -634,8 +647,8 @@ class FlashInferIndicesUpdaterPrefill:

     def call_begin_forward(
         self,
-        wrapper_ragged,
-        wrapper_paged,
+        wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper,
+        wrapper_paged: BatchPrefillWithPagedKVCacheWrapper,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,

sglang/srt/layers/logits_processor.py

@@ -24,7 +24,11 @@ from vllm.distributed import (
 )

 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)


 @dataclasses.dataclass
@@ -46,6 +50,10 @@ class LogitsProcessorOutput:
     output_top_logprobs_val: List = None
     output_top_logprobs_idx: List = None

+    # Used by speculative decoding (EAGLE)
+    # The output of transformer layers
+    hidden_states: Optional[torch.Tensor] = None
+

 @dataclasses.dataclass
 class LogitsMetadata:
@@ -61,6 +69,8 @@ class LogitsMetadata:
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         extend_logprob_pruned_lens_cpu = None
@@ -78,6 +88,11 @@ class LogitsMetadata:
         else:
             return_top_logprob = False

+        if forward_batch.spec_info:
+            capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
+        else:
+            capture_hidden_mode = CaptureHiddenMode.NULL
+
         return cls(
             forward_mode=forward_batch.forward_mode,
             top_logprobs_nums=forward_batch.top_logprobs_nums,
@@ -87,6 +102,7 @@ class LogitsMetadata:
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+            capture_hidden_mode=capture_hidden_mode,
         )


@@ -116,7 +132,10 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if
+        if (
+            logits_metadata.forward_mode.is_decode()
+            or logits_metadata.forward_mode.is_target_verify()
+        ):
             last_index = None
             last_hidden = hidden_states
         else:
@@ -137,6 +156,15 @@ class LogitsProcessor(nn.Module):
         if not logits_metadata.return_logprob:
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
+                hidden_states=(
+                    hidden_states
+                    if logits_metadata.capture_hidden_mode.is_full()
+                    else (
+                        last_hidden
+                        if logits_metadata.capture_hidden_mode.is_last()
+                        else None
+                    )
+                ),
             )
         else:
             last_logprobs = self.compute_temp_top_p_normalized_logprobs(

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}