sglang 0.4.2.post3__py3-none-any.whl → 0.4.2.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sglang/check_env.py +1 -0
  2. sglang/srt/constrained/outlines_backend.py +4 -1
  3. sglang/srt/layers/attention/flashinfer_backend.py +34 -41
  4. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -3
  5. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  6. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  7. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  8. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  9. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  10. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  11. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  13. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  14. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/lora/backend/__init__.py +25 -5
  19. sglang/srt/lora/backend/base_backend.py +31 -9
  20. sglang/srt/lora/backend/flashinfer_backend.py +41 -4
  21. sglang/srt/lora/backend/triton_backend.py +34 -4
  22. sglang/srt/lora/layers.py +293 -0
  23. sglang/srt/lora/lora.py +101 -326
  24. sglang/srt/lora/lora_manager.py +101 -269
  25. sglang/srt/lora/mem_pool.py +174 -0
  26. sglang/srt/lora/triton_ops/__init__.py +7 -1
  27. sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
  28. sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
  29. sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
  30. sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
  31. sglang/srt/lora/utils.py +141 -0
  32. sglang/srt/model_executor/cuda_graph_runner.py +4 -0
  33. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  34. sglang/srt/speculative/eagle_utils.py +64 -21
  35. sglang/srt/speculative/eagle_worker.py +1 -0
  36. sglang/version.py +1 -1
  37. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/METADATA +4 -4
  38. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/RECORD +41 -24
  39. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/LICENSE +0 -0
  40. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/WHEEL +0 -0
  41. {sglang-0.4.2.post3.dist-info → sglang-0.4.2.post4.dist-info}/top_level.txt +0 -0
sglang/check_env.py CHANGED
@@ -19,6 +19,7 @@ def is_cuda_v2():
 # List of packages to check versions
 PACKAGE_LIST = [
     "sglang",
+    "sgl_kernel",
     "flashinfer",
     "triton",
     "transformers",
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -35,7 +35,10 @@ is_hip_ = is_hip()
 if is_hip_:
     from outlines_core.fsm.json_schema import build_regex_from_schema
 else:
-    from outlines.fsm.json_schema import build_regex_from_schema
+    try:
+        from outlines.fsm.json_schema import build_regex_from_schema
+    except ImportError:
+        from outlines_core.fsm.json_schema import build_regex_from_schema


 logger = logging.getLogger(__name__)
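The try/except keeps the import working across outlines releases: newer versions relocate this helper into the outlines_core package, which the fallback covers. The function compiles a JSON schema string into a regex used to constrain generation; a hedged usage sketch (the schema is illustrative):

    try:
        from outlines.fsm.json_schema import build_regex_from_schema
    except ImportError:
        from outlines_core.fsm.json_schema import build_regex_from_schema

    # Turn a JSON schema into a regex that constrains decoding.
    regex = build_regex_from_schema('{"type": "integer"}')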
sglang/srt/layers/attention/flashinfer_backend.py CHANGED
@@ -70,6 +70,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()

+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -130,12 +132,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]

-        # Create wrappers
-        # NOTE: we do not use ragged attention when there are multiple wrappers
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
         )

         # Two wrappers: one for sliding window attention and one for full attention.
@@ -217,13 +215,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-            # Some heuristics to check whether to use ragged forward
-            if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-                use_ragged = True
-                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-            else:
+            if self.is_multimodal:
                 use_ragged = False
                 extend_no_prefix = False
+            else:
+                use_ragged = True
+                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
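Ragged prefill is now the default and is only switched off for multimodal models, replacing the old token-count heuristic (extend_num_tokens >= 4096 with a single wrapper). The new gating restated as a standalone sketch (names mirror the diff; the function itself is illustrative):

    def choose_prefill_mode(is_multimodal, extend_prefix_lens_cpu):
        # Returns (use_ragged, extend_no_prefix) as in the updated backend.
        if is_multimodal:
            return False, False
        return True, not any(extend_prefix_lens_cpu)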
@@ -409,9 +406,9 @@ class FlashInferAttnBackend(AttentionBackend):
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
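Dropping the .contiguous() calls avoids a potential copy before the ragged forward, but torch.Tensor.view only reinterprets layouts it can express without copying, so the change assumes q, k, and v already arrive contiguous here. A quick illustration of the distinction:

    import torch

    t = torch.randn(4, 6).t()     # transpose: non-contiguous layout
    ok = t.reshape(-1, 3)         # reshape copies when it must
    try:
        t.view(-1, 3)             # view never copies; incompatible layouts raise
    except RuntimeError:
        print("view refuses non-contiguous layouts; .contiguous() would be needed")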
@@ -640,7 +637,6 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        wrapper.end_forward()
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -651,6 +647,7 @@ class FlashInferIndicesUpdaterDecode:
             1,
             data_type=self.data_type,
             q_data_type=self.q_data_type,
+            non_blocking=True,
         )

@@ -860,7 +857,6 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
             wrapper_ragged.begin_forward(
                 qo_indptr,
                 qo_indptr,
@@ -871,7 +867,6 @@ class FlashInferIndicesUpdaterPrefill:
             )

         # cached part
-        wrapper_paged.end_forward()
         wrapper_paged.begin_forward(
             qo_indptr,
             kv_indptr,
@@ -883,6 +878,7 @@ class FlashInferIndicesUpdaterPrefill:
             1,
             q_data_type=self.q_data_type,
             custom_mask=custom_mask,
+            non_blocking=True,
         )

@@ -924,38 +920,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]

-    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
+
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-            self.cuda_graph_kv_indices,
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-            self.kv_indptr_stride,
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
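Threading kv_indices_buffer through common_template lets eager execution allocate a fresh scratch buffer per call (as in init_forward_metadata above), while the CUDA-graph paths below keep reusing the persistent self.cuda_graph_kv_indices. The per-step slice length seq_lens_sum * self.topk + bs * (i + 1) grows by one drafted token per branch each step; a worked example with illustrative numbers (not taken from the diff):

    num_seqs, topk, steps, seq_lens_sum = 2, 4, 3, 100
    bs = topk * num_seqs                          # 8 draft branches
    for i in range(steps):
        n = seq_lens_sum * topk + bs * (i + 1)    # 408, 416, 424
        print(f"step {i}: uses kv_indices[:{n}]")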
@@ -965,7 +973,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +981,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1002,7 @@ class FlashInferMultiStepDraftBackend:
             ][0]
             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1016,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


 @triton.jit
@@ -1070,21 +1077,6 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"

-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads
@@ -1114,6 +1106,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
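Accepting **kwargs lets fast_decode_plan silently absorb keyword arguments it does not use, such as the non_blocking=True that the updated begin_forward call sites above now pass. A minimal sketch of the pattern (the function is hypothetical):

    def fast_plan_sketch(kv_indptr, kv_indices, **kwargs):
        # Extra keywords (e.g. non_blocking=True) are accepted and ignored,
        # so callers can use one signature for both plan implementations.
        return len(kv_indptr) - 1  # illustrative: the batch size

    fast_plan_sketch([0, 3, 7], [0, 1, 2, 3, 4, 5, 6], non_blocking=True)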
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py CHANGED
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip

 is_hip_flag = is_hip()
-from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+if _is_cuda:
+    from sgl_kernel import gelu_and_mul, silu_and_mul
+
+if _is_cuda or _is_rocm:
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+

 @triton.jit
 def fused_moe_kernel(
@@ -989,9 +998,15 @@ def fused_experts_impl(
         )

         if activation == "silu":
-            ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         elif activation == "gelu":
-            ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+            if _is_cuda:
+                gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+            else:
+                ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
         else:
             raise ValueError(f"Unsupported activation: {activation=}")

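Note the argument order flips between the two kernels: the vllm-style ops.silu_and_mul(out, x) versus sgl_kernel's silu_and_mul(x, out) as called here. Both compute the standard gated activation over a (tokens, 2*N) input; a pure-PyTorch reference for what either kernel produces:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
        # x packs gate and up projections side by side: (tokens, 2 * d).
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]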
sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
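Each of these JSON files maps a batch-size bucket (the top-level keys) to a tuned Triton launch configuration for one fp8 block-quantized GEMM shape on the NVIDIA H20. A hedged sketch of how such per-shape files are typically consumed at runtime (the nearest-key lookup mirrors the common fused-MoE tuning scheme; the helper itself is illustrative):

    import json

    def load_tuned_config(path: str, M: int) -> dict:
        with open(path) as f:
            configs = json.load(f)
        # Pick the tuned batch-size bucket closest to the runtime M.
        best_key = min(configs, key=lambda k: abs(int(k) - M))
        return configs[best_key]

    # e.g. M=20 selects the "16" or "24" entry, whichever is nearer.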