sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- sglang/bench_one_batch.py +19 -3
- sglang/bench_serving.py +8 -8
- sglang/compile_deep_gemm.py +177 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +1 -1
- sglang/srt/configs/model_config.py +11 -2
- sglang/srt/constrained/llguidance_backend.py +78 -61
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +34 -1
- sglang/srt/disaggregation/decode.py +96 -5
- sglang/srt/disaggregation/mini_lb.py +113 -15
- sglang/srt/disaggregation/mooncake/conn.py +199 -32
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +119 -20
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +11 -9
- sglang/srt/function_call_parser.py +132 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/base_attn_backend.py +3 -0
- sglang/srt/layers/attention/flashattention_backend.py +809 -160
- sglang/srt/layers/attention/flashmla_backend.py +8 -11
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/attention/vision.py +2 -0
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +42 -5
- sglang/srt/layers/logits_processor.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
- sglang/srt/layers/pooler.py +6 -0
- sglang/srt/layers/quantization/awq.py +5 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +385 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/int8_kernel.py +32 -1
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/radix_attention.py +13 -3
- sglang/srt/layers/rotary_embedding.py +176 -132
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +17 -4
- sglang/srt/managers/io_struct.py +21 -3
- sglang/srt/managers/mm_utils.py +85 -28
- sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
- sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
- sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
- sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
- sglang/srt/managers/schedule_batch.py +42 -12
- sglang/srt/managers/scheduler.py +47 -26
- sglang/srt/managers/tokenizer_manager.py +120 -30
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +118 -13
- sglang/srt/model_executor/cuda_graph_runner.py +16 -10
- sglang/srt/model_executor/forward_batch_info.py +51 -95
- sglang/srt/model_executor/model_runner.py +29 -27
- sglang/srt/models/deepseek.py +12 -2
- sglang/srt/models/deepseek_nextn.py +101 -6
- sglang/srt/models/deepseek_v2.py +153 -76
- sglang/srt/models/deepseek_vl2.py +9 -4
- sglang/srt/models/gemma3_causal.py +1 -1
- sglang/srt/models/llama4.py +0 -1
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +22 -7
- sglang/srt/models/mllama4.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +3 -6
- sglang/srt/models/qwen2_vl.py +3 -7
- sglang/srt/models/roberta.py +178 -0
- sglang/srt/openai_api/adapter.py +87 -10
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +65 -60
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/torch_memory_saver_adapter.py +10 -1
- sglang/srt/utils.py +48 -6
- sglang/test/runners.py +6 -13
- sglang/test/test_utils.py +39 -19
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashmla_backend.py
CHANGED
@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         self.num_q_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
-        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.num_local_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            )
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
-                Q_LEN * self.num_q_heads // self.num_kv_heads,
-                self.num_kv_heads,
+                Q_LEN * self.num_q_heads,
+                1,
            )
            self.forward_metadata = FlashMLADecodeMetadata(
                mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):

        self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
            torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
        )
        self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
        )
        mla_metadata, num_splits = get_mla_metadata(
            seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
        )
        self.cuda_graph_mla_metadata.copy_(mla_metadata)
        self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
        )
        mla_metadata, num_splits = get_mla_metadata(
            seq_lens.to(torch.int32),
-            Q_LEN * self.num_q_heads // self.num_kv_heads,
-            self.num_kv_heads,
+            Q_LEN * self.num_q_heads,
+            1,
        )
        self.cuda_graph_mla_metadata.copy_(mla_metadata)
        self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
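The recurring edit in this file changes the call shape of get_mla_metadata: rather than dividing the query-head count across self.num_kv_heads and passing the KV-head count separately, the backend now treats MLA's shared latent KV cache as a single KV head. A minimal sketch of the new call, assuming FlashMLA's get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) signature; the sizes below are illustrative:

import torch
from flash_mla import get_mla_metadata  # assumes the FlashMLA package is installed

# MLA keeps one shared latent KV per token, so num_heads_k is pinned to 1 and
# every query head counts toward the per-KV-head total.
Q_LEN, num_q_heads = 1, 16  # one decode step, 16 local query heads (illustrative)
seq_lens = torch.tensor([37, 512], dtype=torch.int32, device="cuda")

mla_metadata, num_splits = get_mla_metadata(
    seq_lens,             # per-request KV cache lengths
    Q_LEN * num_q_heads,  # query heads per KV head
    1,                    # a single latent KV head
)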
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
CHANGED
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
|
@@ -22,8 +22,12 @@ import torch
|
|
22
22
|
import triton
|
23
23
|
import triton.language as tl
|
24
24
|
|
25
|
-
|
26
|
-
|
25
|
+
from sglang.srt.utils import is_cuda, is_hip
|
26
|
+
|
27
|
+
_is_cuda = is_cuda()
|
28
|
+
_is_hip = is_hip()
|
29
|
+
|
30
|
+
if _is_cuda or _is_hip:
|
27
31
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
28
32
|
|
29
33
|
|
@@ -172,7 +176,7 @@ def context_attention_fwd(
|
|
172
176
|
b_seq_len: [b]
|
173
177
|
out: [b * s, head, head_dim]
|
174
178
|
"""
|
175
|
-
if
|
179
|
+
if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
|
176
180
|
BLOCK = 128
|
177
181
|
else:
|
178
182
|
BLOCK = 64
|
sglang/srt/layers/layernorm.py
CHANGED
@@ -20,9 +20,10 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda, is_hip

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import (
@@ -32,6 +33,8 @@ if _is_cuda:
         rmsnorm,
     )

+if _is_hip:
+    from vllm._custom_ops import fused_add_rms_norm, rms_norm

 logger = logging.getLogger(__name__)

@@ -46,23 +49,49 @@ class RMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        elif _is_hip:
+            return self.forward_hip(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
             return x, residual
         out = rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out

+    def forward_hip(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            # NOTE: Romove this if aiter kernel supports discontinuous input
+            x = x.contiguous()
+        if residual is not None:
+            fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = torch.empty_like(x)
+        rms_norm(out, x, self.weight.data, self.variance_epsilon)
+        return out
+
     def forward_native(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -88,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
         self.weight = nn.Parameter(torch.zeros(hidden_size))
         self.variance_epsilon = eps

+    def forward(self, *args, **kwargs):
+        if torch.compiler.is_compiling():
+            return self.forward_native(*args, **kwargs)
+        if _is_cuda:
+            return self.forward_cuda(*args, **kwargs)
+        else:
+            return self.forward_native(*args, **kwargs)
+
     def forward_native(
         self,
         x: torch.Tensor,
@@ -139,8 +176,8 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
-        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
     from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
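For reference, the three RMSNorm paths above compute the same function; the CUDA and HIP branches call fused kernels for it. A pure-PyTorch sketch of the residual-fused RMSNorm semantics, mirroring the shape of forward_native (illustrative, not the shipped implementation):

import torch

def rms_norm_reference(x, weight, eps, residual=None):
    # Accumulate in float32, as forward_native does.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        # The fused kernels add the residual first and return the pre-norm sum.
        x = x + residual.to(torch.float32)
        residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = (x * weight.to(torch.float32)).to(orig_dtype)
    return x if residual is None else (x, residual)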
sglang/srt/layers/logits_processor.py
CHANGED
@@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module):
                aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
                hidden_states_to_store = (
                    aux_pruned_states[sample_indices]
-                    if sample_indices
+                    if sample_indices is not None
                    else aux_pruned_states
                )
            else:
                hidden_states_to_store = (
                    pruned_states[sample_indices]
-                    if sample_indices
+                    if sample_indices is not None
                    else pruned_states
                )
        else:
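The `is not None` guards fix a classic PyTorch pitfall: truthiness of a tensor does not mean "an index tensor was provided". A one-element index tensor holding 0 is falsy, and a multi-element tensor raises; a quick illustration:

import torch

sample_indices = torch.tensor([0])  # a valid index selecting row 0
print(bool(sample_indices))         # False: truthiness of the scalar value 0
# bool(torch.tensor([0, 2])) raises RuntimeError: "Boolean value of Tensor
# with more than one element is ambiguous."
# `sample_indices is not None` asks the intended question instead.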
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -802,6 +802,7 @@ class DeepEPMoE(EPMoE):
        correction_bias: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
        deepep_mode: DeepEPMode = DeepEPMode.auto,
    ):
        super().__init__(
@@ -820,6 +821,7 @@ class DeepEPMoE(EPMoE):
            correction_bias,
            custom_routing_function,
            activation,
+            routed_scaling_factor,
        )
        self.deepep_mode = deepep_mode
        if self.deepep_mode.enable_low_latency():
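routed_scaling_factor is now threaded from DeepEPMoE into its EPMoE base. In DeepSeek-style MoE such a factor scales the combined routed-expert output; a hypothetical sketch of where a factor like this applies (names and shapes are illustrative, not the EPMoE internals):

import torch

def combine_expert_outputs(expert_out, topk_weights, routed_scaling_factor=None):
    # expert_out: [tokens, topk, hidden]; topk_weights: [tokens, topk]
    out = (expert_out * topk_weights.unsqueeze(-1)).sum(dim=1)
    if routed_scaling_factor is not None:
        out = out * routed_scaling_factor  # scalar, e.g. DeepSeek's routed scaling
    return out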
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

+from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts


@@ -30,7 +31,7 @@ def fused_moe_forward_native(
 ) -> torch.Tensor:

     if apply_router_weight_on_input:
-        raise NotImplementedError
+        raise NotImplementedError()

     topk_weights, topk_ids = select_experts(
         hidden_states=x,
@@ -75,9 +76,6 @@ def moe_forward_native(
     activation: str = "silu",
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
-
-    from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json
CHANGED
@@ -1,102 +1,102 @@
 {
   "1": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 64,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 1,
     "num_warps": 4,
     "num_stages": 4
   },
   "2": {
-    "BLOCK_SIZE_M": …,
-    "BLOCK_SIZE_N": …,
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 16,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "4": {
-    "BLOCK_SIZE_M": …,
-    "BLOCK_SIZE_N": …,
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 16,
     "num_warps": 4,
     "num_stages": 4
   },
   "8": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
     "GROUP_SIZE_M": 32,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "16": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 1,
     "num_warps": 4,
     "num_stages": 3
   },
   "24": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 1,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "32": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 16,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 5
   },
   "48": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 64,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "64": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 32,
     "num_warps": 4,
     "num_stages": 3
   },
   "96": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 32,
     "num_warps": 4,
     "num_stages": 3
   },
   "128": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 64,
     "num_warps": 4,
     "num_stages": 3
   },
   "256": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 64,
     "num_warps": 4,
     "num_stages": 3
   },
   "512": {
-    "BLOCK_SIZE_M": …,
+    "BLOCK_SIZE_M": 16,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
     "GROUP_SIZE_M": 16,
@@ -107,9 +107,9 @@
     "BLOCK_SIZE_M": 64,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 16,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "1536": {
     "BLOCK_SIZE_M": 64,
@@ -117,21 +117,21 @@
     "BLOCK_SIZE_K": 128,
     "GROUP_SIZE_M": 32,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "2048": {
     "BLOCK_SIZE_M": 64,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 32,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   },
   "3072": {
-    "BLOCK_SIZE_M": …,
-    "BLOCK_SIZE_N": …,
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 16,
     "num_warps": 4,
     "num_stages": 3
   },
@@ -139,8 +139,8 @@
     "BLOCK_SIZE_M": 64,
     "BLOCK_SIZE_N": 128,
     "BLOCK_SIZE_K": 128,
-    "GROUP_SIZE_M": …,
+    "GROUP_SIZE_M": 16,
     "num_warps": 4,
-    "num_stages": …
+    "num_stages": 4
   }
 }
sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json
ADDED
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "256": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "512": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  }
+}
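These per-batch-size tables feed the Triton launch-parameter selection in fused_moe.py. A sketch of the typical lookup, which falls back to the nearest tuned token count when the runtime M is not an exact key (pick_config is a hypothetical helper, not the shipped API):

import json

def pick_config(configs: dict, M: int) -> dict:
    # Keys are the token counts the kernel was tuned for.
    best_key = min(configs, key=lambda k: abs(int(k) - M))
    return configs[best_key]

path = (
    "E=272,N=128,device_name=NVIDIA_H20,"
    "dtype=fp8_w8a8,block_shape=[128, 128].json"
)
with open(path) as f:
    configs = json.load(f)

print(pick_config(configs, M=50))  # nearest tuned size -> the "48" entry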