sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/layers/attention/flashinfer_backend.py +34 -41
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/model_executor/cuda_graph_runner.py +4 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/check_env.py CHANGED

sglang/srt/constrained/outlines_backend.py CHANGED
@@ -35,7 +35,10 @@ is_hip_ = is_hip()
 if is_hip_:
     from outlines_core.fsm.json_schema import build_regex_from_schema
 else:
-    from outlines.fsm.json_schema import build_regex_from_schema
+    try:
+        from outlines.fsm.json_schema import build_regex_from_schema
+    except ImportError:
+        from outlines_core.fsm.json_schema import build_regex_from_schema
 
 
 logger = logging.getLogger(__name__)
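The hunk above guards the `outlines` import with a fallback to `outlines_core`, so the backend keeps working whichever package actually provides `build_regex_from_schema`. A minimal standalone sketch of the same pattern (only the module paths and function name come from the hunk; the usage line is illustrative):

# Try the original location first, fall back to the split-out outlines_core package.
try:
    from outlines.fsm.json_schema import build_regex_from_schema
except ImportError:
    from outlines_core.fsm.json_schema import build_regex_from_schema

# Usage: turn a JSON schema into a regex usable for constrained decoding.
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
regex = build_regex_from_schema(schema)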
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()
 
+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]
 
-
-
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
         )
 
         # Two wrappers: one for sliding window attention and one for full attention.
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens
 
-
-        if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-            use_ragged = True
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-        else:
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
+        else:
+            use_ragged = True
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
 
         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
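The hunk above replaces the old token-count heuristic (`extend_num_tokens >= 4096 and num_wrappers == 1`) with a model-type check: multimodal models always take the paged prefill path, while text-only models now always use the ragged wrapper. A small, hypothetical distillation of the new branch (the function and variable names are illustrative, not part of sglang):

from typing import List, Tuple

def choose_prefill_path(is_multimodal: bool, extend_prefix_lens_cpu: List[int]) -> Tuple[bool, bool]:
    """Mirror of the branch above: returns (use_ragged, extend_no_prefix)."""
    if is_multimodal:
        # Multimodal inputs (e.g. image embeddings) go through the paged wrapper only.
        return False, False
    # Text-only extend: use the ragged wrapper; skip the cached-prefix pass
    # when no request in the batch carries a prefix.
    return True, not any(extend_prefix_lens_cpu)

# A text-only batch with no cached prefixes uses the ragged, no-prefix path.
assert choose_prefill_path(False, [0, 0, 0]) == (True, True)
# A multimodal batch always takes the paged path.
assert choose_prefill_path(True, [0, 5]) == (False, False)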
@@ -409,9 +406,9 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.
-                k.
-                v.
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1
 
-        wrapper.end_forward()
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )
 
 
@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:
 
         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )
 
         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             custom_mask=custom_mask,
+            non_blocking=True,
         )
 
 
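These hunks drop the explicit `end_forward()` calls and pass `non_blocking=True` when (re)planning the FlashInfer wrappers, so the host-to-device copies of the index tensors no longer force a CPU-side synchronization. As a generic illustration of what a non-blocking copy means in PyTorch (this is not sglang or FlashInfer code; the tensor names are made up):

import torch

if torch.cuda.is_available():
    # Pinned (page-locked) host memory is what makes non_blocking copies truly async.
    kv_indptr_host = torch.arange(0, 1025, dtype=torch.int32).pin_memory()

    # The copy is enqueued on the current stream and the CPU continues immediately.
    kv_indptr_dev = kv_indptr_host.to("cuda", non_blocking=True)

    # Kernels queued later on the same stream are ordered after the copy, so they
    # see the full data; an explicit sync is only needed to inspect it eagerly.
    torch.cuda.current_stream().synchronize()
    print(kv_indptr_dev[-1].item())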
@@ -924,38 +920,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]
 
-    def common_template(
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
+
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices =
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
@@ -965,7 +973,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +981,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1002,7 @@ class FlashInferMultiStepDraftBackend:
             ][0]
             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1016,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
 @triton.jit
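The `common_template` change threads an explicit `kv_indices_buffer` through the multi-step draft backend instead of a stride cached at init time: `init_forward_metadata` allocates a `(num_steps, batch_size * topk * max_context_len)` int32 buffer, the Triton kernel fills all steps into it, and each step `i` then slices its own view. A simplified sketch of that allocate-then-slice pattern (shapes and the slice bound follow the hunk; the fill step stands in for the Triton kernel):

import torch

speculative_num_steps, batch_size, topk, max_context_len = 3, 2, 4, 16
seq_lens_sum = 20  # total tokens across the batch (illustrative value)

# One row per draft step, sized for the worst case.
kv_indices = torch.zeros(
    (speculative_num_steps, batch_size * topk * max_context_len),
    dtype=torch.int32,
)

# A real implementation fills this with generate_draft_decode_kv_indices; here we just mark it.
kv_indices.fill_(1)

bs = batch_size * topk
for i in range(speculative_num_steps):
    # Step i consumes the entries valid for that step, as in the hunk:
    # seq_lens_sum * topk base entries plus bs extra entries per completed step.
    step_view = kv_indices[i][: seq_lens_sum * topk + bs * (i + 1)]
    print(i, step_view.shape)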
@@ -1070,21 +1077,6 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"
 
-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
 
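With the `_grouped_size_compiled_for_decode_kernels` probe removed (it only applied to flashinfer <= 0.1.6), `should_use_tensor_core` falls through to the GQA-group-size calculation kept as context above. The sketch below is a hypothetical version of such a heuristic; the dtype cases and the `> 4` threshold are assumptions, not necessarily what sglang ships:

import torch

def should_use_tensor_core_sketch(
    kv_cache_dtype: torch.dtype,
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    """Assumed heuristic: tensor cores pay off for fp8 KV caches or wide GQA groups."""
    gqa_group_size = num_attention_heads // num_kv_heads
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    if kv_cache_dtype in (torch.float16, torch.bfloat16):
        return gqa_group_size > 4  # assumed threshold
    return False

print(should_use_tensor_core_sketch(torch.bfloat16, 32, 8))  # group size 4 -> False
print(should_use_tensor_core_sketch(torch.bfloat16, 32, 4))  # group size 8 -> True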
@@ -1114,6 +1106,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
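`fast_decode_plan` is monkey-patched over the wrapper's `begin_forward` via `partial`, so accepting `**kwargs` lets callers pass newer keyword arguments (such as the `non_blocking=True` added above) without breaking the patched planner. A tiny, self-contained illustration of the pattern (the names are generic, not the FlashInfer API):

from functools import partial

class Wrapper:
    def begin_forward(self, indptr, *, data_type="float16"):
        print("slow plan", len(indptr), data_type)

def fast_plan(wrapper, indptr, *, data_type="float16", **kwargs):
    # Extra keywords (e.g. non_blocking=True) are accepted and ignored here,
    # so call sites can evolve without breaking the patched planner.
    print("fast plan", len(indptr), data_type, kwargs)

w = Wrapper()
w.begin_forward = partial(fast_plan, w)  # monkey-patch, as in the hunk above
w.begin_forward([0, 4, 9], data_type="float16", non_blocking=True)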
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py CHANGED
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
 is_hip_flag = is_hip()
-
+
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
 
 @triton.jit
 def fused_moe_kernel(
@@ -989,9 +998,15 @@ def fused_experts_impl(
         )
 
         if activation == "silu":
-            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         elif activation == "gelu":
-            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported activation: {activation=}")
 
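The fused-MoE change imports `silu_and_mul`/`gelu_and_mul` from `sgl_kernel` on CUDA and keeps the previous `ops.*` call as the fallback; note that the two paths take their arguments in opposite order in the hunk (`sgl_kernel` is called as `(input, out)`, the fallback as `(out, input)`). A hedged sketch of such a dispatch wrapper; the pure-PyTorch fallback below is illustrative, not the `ops` implementation:

import torch
import torch.nn.functional as F

try:
    from sgl_kernel import silu_and_mul as _sgl_silu_and_mul  # CUDA fast path
except ImportError:
    _sgl_silu_and_mul = None

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    """Compute silu(x[..., :d]) * x[..., d:] for an expert's gate/up projection."""
    d = x.shape[-1] // 2
    out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
    if _sgl_silu_and_mul is not None and x.is_cuda:
        _sgl_silu_and_mul(x, out)  # (input, out), as called in the hunk above
    else:
        out.copy_(F.silu(x[..., :d]) * x[..., d:])  # reference fallback
    return out

# Example: 4 tokens, intermediate size 8 (the input carries 2 * 8 features).
y = silu_and_mul(torch.randn(4, 16))
print(y.shape)  # torch.Size([4, 8])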
N=…,K=…,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json ADDED
@@ -0,0 +1,146 @@
{
    "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
    "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
    "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
    "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 5},
    "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "96": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
    "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
    "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
    "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
    "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}
}
N=…,K=…,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json ADDED
@@ -0,0 +1,146 @@
{
    "1": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 5},
    "2": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "4": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
    "8": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
    "16": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "24": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5},
    "32": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
    "48": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
    "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
    "96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 5},
    "128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
    "256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
    "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
    "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
    "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
    "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
    "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
    "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}
}
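The new JSON files above are per-shape Triton tuning tables for the H20 fp8 block-quant MoE kernels: each key is a token count M and maps to the block sizes, group size, warp count, and stage count to launch with. A hedged sketch of how such a table is typically consulted (nearest-key lookup; the helpers below are illustrative and not sglang's actual loader):

import json
from typing import Any, Dict

def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
    """Read one of the N=...,K=...,device_name=... .json files shown above."""
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: Dict[int, Dict[str, Any]], M: int) -> Dict[str, Any]:
    # Assumed strategy: use the tuned entry whose M key is closest to the actual
    # number of tokens; real loaders may differ in tie-breaking or fallbacks.
    best_m = min(configs, key=lambda m: abs(m - M))
    return configs[best_m]

# Example with two entries inlined from the first table above.
table = {
    1: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
    256: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
}
print(pick_config(table, 200))  # falls back to the M=256 entry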