sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -46,11 +46,11 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    mask_ptr,
+    mask_offsets,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -65,7 +65,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -74,19 +73,21 @@ def _fwd_kernel(
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    USE_CUSTOM_MASK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num
 
-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
+    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    if USE_CUSTOM_MASK:
+        cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -97,7 +98,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv
 
     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
@@ -109,7 +110,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]
@@ -126,10 +127,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
 
         # load k in transposed way
         offs_buf_k = (
@@ -159,7 +159,20 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -179,7 +192,7 @@ def _fwd_kernel(
 
         e_max = n_e_max
 
-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part
 
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -188,7 +201,7 @@ def _fwd_kernel(
 
         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
@@ -199,8 +212,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )
@@ -216,11 +228,25 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
-            start_n + offs_n[None, :]
-        )
-        mask_causual &= mask_m[:, None] & mask_n[None, :]
-        qk = tl.where(mask_causual, qk, float("-inf"))
+        if USE_CUSTOM_MASK:
+            custom_mask = tl.load(
+                mask_ptr
+                + cur_seq_mask_start_idx
+                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + cur_seq_len_prefix
+                + start_n
+                + offs_n[None, :],
+                mask=(mask_m[:, None] & mask_n[None, :]),
+                other=0,
+            )
+            custom_mask &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(custom_mask, qk, float("-inf"))
+        else:
+            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+                start_n + offs_n[None, :]
+            )
+            mask_causual &= mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_causual, qk, float("-inf"))
 
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
@@ -228,7 +254,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)
 
         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
            + cur_kv_head * stride_vh
            + offs_dv[None, :]
         )
@@ -241,7 +267,7 @@ def _fwd_kernel(
         e_max = n_e_max
 
     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]
@@ -258,11 +284,11 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
+    custom_mask,
+    mask_offsets,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -315,15 +341,17 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
 
     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
+    USE_CUSTOM_MASK = custom_mask is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1
 
     extra_kargs = {}
     if is_hip_:
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
         q_extend,
@@ -332,11 +360,11 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        custom_mask,
+        mask_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -351,7 +379,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
@@ -360,6 +387,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         Lq=Lq,
         Lv=Lv,
+        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
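
Note: the rewritten entry point drops the req_to_tokens/b_req_idx/b_seq_len table lookup in favor of CSR-style qo_indptr/kv_indptr/kv_indices tensors, plus an optional flattened custom_mask (one row of length prefix_len + extend_len per query token, with mask_offsets marking where each request's rows begin; passing None selects the plain causal path). A minimal sketch of how a caller could assemble these inputs, with made-up lengths and KV-pool slots (only the new argument layout is taken from the diff above):

    import torch
    import torch.nn.functional as F

    # Illustrative batch: request 0 has 3 cached prefix tokens + 2 extend
    # tokens, request 1 has 1 prefix token + 3 extend tokens.
    extend_lens = torch.tensor([2, 3])
    prefix_lens = torch.tensor([3, 1])

    # qo_indptr delimits each request's rows in q_extend/o_extend;
    # kv_indptr/kv_indices gather its prefix tokens out of k_buffer/v_buffer.
    qo_indptr = F.pad(torch.cumsum(extend_lens, 0), (1, 0)).int()   # [0, 2, 5]
    kv_indptr = F.pad(torch.cumsum(prefix_lens, 0), (1, 0)).int()   # [0, 3, 4]
    kv_indices = torch.tensor([10, 11, 12, 40], dtype=torch.int32)  # made-up pool slots

    # Optional mask (e.g. a tree-attention mask): one flattened row of length
    # (prefix_len + extend_len) per query token, concatenated across requests.
    mask_lens = extend_lens * (prefix_lens + extend_lens)                # [10, 12]
    mask_offsets = F.pad(torch.cumsum(mask_lens, 0), (1, 0))[:-1].int()  # [0, 10]
    custom_mask = torch.ones(int(mask_lens.sum()), dtype=torch.bool)  # dtype assumed

    # extend_attention_fwd(q_extend, k_extend, v_extend, o_extend,
    #                      k_buffer, v_buffer, qo_indptr, kv_indptr, kv_indices,
    #                      custom_mask, mask_offsets,
    #                      max_len_extend=int(extend_lens.max()))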
sglang/srt/layers/layernorm.py

@@ -29,14 +29,11 @@ if is_cuda_available():
         rmsnorm,
     )
 
-from vllm.model_executor.custom_op import CustomOp
-
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
         return x, residual
 
 
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
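
Note: this hunk (and the EPMoE one below) swaps vllm.model_executor.custom_op.CustomOp for the in-tree sglang.srt.custom_op module added in this release (sglang/srt/custom_op.py, +40 lines) and drops the register_custom_op indirection. The new file itself is not shown in this diff; as a rough sketch of the pattern such a base class implements, it resolves a per-platform forward once at construction time (method names follow the vLLM convention the code previously relied on; this is an assumption, not the file's verbatim contents):

    import torch
    from torch import nn

    class CustomOp(nn.Module):
        # Subclasses such as RMSNorm override forward_native/forward_cuda;
        # the base picks which one to call when the module is built.
        def __init__(self):
            super().__init__()
            self._forward_method = self.dispatch_forward()

        def forward(self, *args, **kwargs):
            return self._forward_method(*args, **kwargs)

        def forward_native(self, *args, **kwargs):
            # Pure-PyTorch fallback for platforms without a fused kernel.
            raise NotImplementedError

        def forward_cuda(self, *args, **kwargs):
            raise NotImplementedError

        def forward_hip(self, *args, **kwargs):
            # ROCm commonly reuses the CUDA path.
            return self.forward_cuda(*args, **kwargs)

        def dispatch_forward(self):
            if not torch.cuda.is_available():
                return self.forward_native
            if torch.version.hip is not None:
                return self.forward_hip
            return self.forward_cuda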
sglang/srt/layers/moe/ep_moe/layer.py

@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
         param_data[expert_id] = loaded_weight
 
 
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json (new file)

@@ -0,0 +1,164 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json (new file)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
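
Note: the top-level keys in these tuning tables ("1" through "4096") are token counts M, and fused-MoE-style Triton launchers pick the entry whose key is numerically closest to the runtime batch size before launching the kernel. A sketch of that nearest-M lookup (the helper name is illustrative; the exact sglang selection code is not part of this diff):

    import json

    def load_moe_kernel_config(path: str, m: int) -> dict:
        # Pick the tuning entry whose integer key is closest to m.
        with open(path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        return configs[min(configs, key=lambda k: abs(k - m))]

    # With the NVIDIA_B200 table above, m=37 sits between the "32" and "48"
    # entries and resolves to "32":
    #   {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, ...}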