sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/layers/activation.py +10 -5
  5. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  6. sglang/srt/layers/attention/triton_backend.py +71 -7
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  8. sglang/srt/layers/layernorm.py +1 -5
  9. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -11
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  20. sglang/srt/layers/moe/topk.py +4 -0
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/fp8_kernel.py +140 -2
  44. sglang/srt/layers/rotary_embedding.py +1 -3
  45. sglang/srt/layers/sampler.py +4 -4
  46. sglang/srt/lora/backend/__init__.py +8 -0
  47. sglang/srt/lora/backend/base_backend.py +95 -0
  48. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  49. sglang/srt/lora/backend/triton_backend.py +61 -0
  50. sglang/srt/lora/lora.py +127 -112
  51. sglang/srt/lora/lora_manager.py +50 -18
  52. sglang/srt/lora/triton_ops/__init__.py +5 -0
  53. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  54. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  55. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  56. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  57. sglang/srt/model_executor/forward_batch_info.py +58 -59
  58. sglang/srt/model_executor/model_runner.py +2 -2
  59. sglang/srt/models/qwen2_vl.py +1 -1
  60. sglang/srt/server_args.py +13 -2
  61. sglang/srt/speculative/build_eagle_tree.py +4 -2
  62. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  63. sglang/srt/speculative/eagle_utils.py +361 -372
  64. sglang/srt/speculative/eagle_worker.py +177 -45
  65. sglang/srt/utils.py +7 -0
  66. sglang/test/runners.py +2 -0
  67. sglang/version.py +1 -1
  68. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/METADATA +15 -6
  69. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/RECORD +72 -33
  70. sglang/srt/layers/custom_op_util.py +0 -25
  71. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/LICENSE +0 -0
  72. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/WHEEL +0 -0
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post2.dist-info}/top_level.txt +0 -0
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
49
49
  K_Buffer,
50
50
  V_Buffer,
51
51
  sm_scale,
52
- Req_to_tokens,
53
- B_req_idx,
54
- B_Seqlen,
52
+ kv_indptr,
53
+ kv_indices,
55
54
  Att_Out,
56
- stride_req_to_tokens_b,
57
55
  stride_qbs,
58
56
  stride_qh,
59
57
  stride_buf_kbs,
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
82
80
  offs_dv = tl.arange(0, BLOCK_DV)
83
81
  mask_d = offs_d < Lk
84
82
  mask_dv = offs_dv < Lv
85
- cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
86
- cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
83
+
84
+ cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
85
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
87
86
 
88
87
  off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
89
88
  q = tl.load(Q + off_q, mask=mask_d, other=0.0)
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
100
99
  for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
101
100
  offs_n = start_n + tl.arange(0, BLOCK_N)
102
101
  kv_loc = tl.load(
103
- Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
102
+ kv_indices + cur_batch_kv_start_idx + offs_n,
104
103
  mask=offs_n < split_kv_end,
105
104
  other=0,
106
105
  )
@@ -173,19 +172,21 @@ def _decode_att_m_fwd(
173
172
  k_buffer,
174
173
  v_buffer,
175
174
  att_out,
176
- Req_to_tokens,
177
- B_req_idx,
178
- B_Seqlen,
175
+ kv_indptr,
176
+ kv_indices,
179
177
  num_kv_splits,
180
178
  sm_scale,
181
179
  logit_cap,
182
180
  ):
183
181
  BLOCK = 64
182
+ # [TODO] work around SGPR limit on MI3xx
183
+ if is_hip_:
184
+ BLOCK = 8
184
185
  NUM_KV_SPLITS = num_kv_splits
185
186
  Lk = k_buffer.shape[-1]
186
187
  Lv = v_buffer.shape[-1]
187
188
 
188
- batch, head_num = B_req_idx.shape[0], q.shape[1]
189
+ batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
189
190
 
190
191
  grid = (batch, head_num, NUM_KV_SPLITS)
191
192
  kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -194,6 +195,8 @@ def _decode_att_m_fwd(
194
195
  num_warps = 4
195
196
  else:
196
197
  num_warps = 2
198
+ if is_hip_:
199
+ num_warps = 1
197
200
 
198
201
  BLOCK_DMODEL = triton.next_power_of_2(Lk)
199
202
  BLOCK_DV = triton.next_power_of_2(Lv)
@@ -203,11 +206,9 @@ def _decode_att_m_fwd(
203
206
  k_buffer,
204
207
  v_buffer,
205
208
  sm_scale,
206
- Req_to_tokens,
207
- B_req_idx,
208
- B_Seqlen,
209
+ kv_indptr,
210
+ kv_indices,
209
211
  att_out,
210
- Req_to_tokens.stride(0),
211
212
  q.stride(0),
212
213
  q.stride(1),
213
214
  k_buffer.stride(0),
@@ -236,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
236
237
  K_Buffer,
237
238
  V_Buffer,
238
239
  sm_scale,
239
- Req_to_tokens,
240
- B_req_idx,
241
- B_Seqlen,
240
+ kv_indptr,
241
+ kv_indices,
242
242
  Att_Out,
243
- stride_req_to_tokens_b,
244
243
  stride_qbs,
245
244
  stride_qh,
246
245
  stride_buf_kbs,
@@ -279,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
279
278
  offs_dv = tl.arange(0, BLOCK_DV)
280
279
  mask_d = offs_d < Lk
281
280
  mask_dv = offs_dv < Lv
282
- cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
283
- cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
281
+
282
+ cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
283
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
284
284
 
285
285
  offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
286
286
  q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -307,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
307
307
  for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
308
308
  offs_n = start_n + tl.arange(0, BLOCK_N)
309
309
  kv_loc = tl.load(
310
- Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
310
+ kv_indices + cur_batch_kv_start_idx + offs_n,
311
311
  mask=offs_n < split_kv_end,
312
312
  other=0,
313
313
  )
@@ -395,9 +395,8 @@ def _decode_grouped_att_m_fwd(
395
395
  k_buffer,
396
396
  v_buffer,
397
397
  att_out,
398
- Req_to_tokens,
399
- B_req_idx,
400
- B_Seqlen,
398
+ kv_indptr,
399
+ kv_indices,
401
400
  num_kv_splits,
402
401
  sm_scale,
403
402
  logit_cap,
@@ -421,7 +420,7 @@ def _decode_grouped_att_m_fwd(
421
420
  BLOCK_DPE = 0
422
421
  BLOCK_DV = triton.next_power_of_2(Lv)
423
422
 
424
- batch, head_num = B_req_idx.shape[0], q.shape[1]
423
+ batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
425
424
  kv_group_num = q.shape[1] // k_buffer.shape[1]
426
425
 
427
426
  BLOCK_H = 16
@@ -433,21 +432,21 @@ def _decode_grouped_att_m_fwd(
433
432
  )
434
433
 
435
434
  extra_kargs = {}
435
+ num_stages = 2
436
436
  if is_hip_:
437
437
  # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
438
438
  # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
439
- extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
439
+ extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
440
+ num_stages = 1
440
441
 
441
442
  _fwd_grouped_kernel_stage1[grid](
442
443
  q,
443
444
  k_buffer,
444
445
  v_buffer,
445
446
  sm_scale,
446
- Req_to_tokens,
447
- B_req_idx,
448
- B_Seqlen,
447
+ kv_indptr,
448
+ kv_indices,
449
449
  att_out,
450
- Req_to_tokens.stride(0),
451
450
  q.stride(0),
452
451
  q.stride(1),
453
452
  k_buffer.stride(0),
@@ -467,7 +466,7 @@ def _decode_grouped_att_m_fwd(
467
466
  NUM_KV_SPLITS=NUM_KV_SPLITS,
468
467
  logit_cap=logit_cap,
469
468
  num_warps=4,
470
- num_stages=2,
469
+ num_stages=num_stages,
471
470
  Lk=Lk,
472
471
  Lv=Lv,
473
472
  **extra_kargs,
@@ -478,7 +477,7 @@ def _decode_grouped_att_m_fwd(
478
477
  def _fwd_kernel_stage2(
479
478
  Mid_O,
480
479
  O,
481
- B_Seqlen,
480
+ kv_indptr,
482
481
  stride_mid_ob,
483
482
  stride_mid_oh,
484
483
  stride_mid_os,
@@ -491,7 +490,9 @@ def _fwd_kernel_stage2(
491
490
  cur_batch = tl.program_id(0)
492
491
  cur_head = tl.program_id(1)
493
492
 
494
- cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
493
+ cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
494
+ kv_indptr + cur_batch
495
+ )
495
496
 
496
497
  offs_d = tl.arange(0, BLOCK_DV)
497
498
  mask_d = offs_d < Lv
@@ -535,7 +536,7 @@ def _decode_softmax_reducev_fwd(
535
536
  q,
536
537
  o,
537
538
  v_buffer,
538
- b_seq_len,
539
+ kv_indptr,
539
540
  num_kv_splits,
540
541
  ):
541
542
  batch, head_num = q.shape[0], q.shape[1]
@@ -554,7 +555,7 @@ def _decode_softmax_reducev_fwd(
554
555
  _fwd_kernel_stage2[grid](
555
556
  logits,
556
557
  o,
557
- b_seq_len,
558
+ kv_indptr,
558
559
  logits.stride(0),
559
560
  logits.stride(1),
560
561
  logits.stride(2),
@@ -574,9 +575,8 @@ def decode_attention_fwd_normal(
574
575
  k_buffer,
575
576
  v_buffer,
576
577
  o,
577
- req_to_token,
578
- b_req_idx,
579
- b_seq_len,
578
+ kv_indptr,
579
+ kv_indices,
580
580
  attn_logits,
581
581
  num_kv_splits,
582
582
  sm_scale,
@@ -587,14 +587,13 @@ def decode_attention_fwd_normal(
587
587
  k_buffer,
588
588
  v_buffer,
589
589
  attn_logits,
590
- req_to_token,
591
- b_req_idx,
592
- b_seq_len,
590
+ kv_indptr,
591
+ kv_indices,
593
592
  num_kv_splits,
594
593
  sm_scale,
595
594
  logit_cap,
596
595
  )
597
- _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
596
+ _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
598
597
 
599
598
 
600
599
  def decode_attention_fwd_grouped(
@@ -602,9 +601,8 @@ def decode_attention_fwd_grouped(
602
601
  k_buffer,
603
602
  v_buffer,
604
603
  o,
605
- req_to_token,
606
- b_req_idx,
607
- b_seq_len,
604
+ kv_indptr,
605
+ kv_indices,
608
606
  attn_logits,
609
607
  num_kv_splits,
610
608
  sm_scale,
@@ -615,14 +613,13 @@ def decode_attention_fwd_grouped(
615
613
  k_buffer,
616
614
  v_buffer,
617
615
  attn_logits,
618
- req_to_token,
619
- b_req_idx,
620
- b_seq_len,
616
+ kv_indptr,
617
+ kv_indices,
621
618
  num_kv_splits,
622
619
  sm_scale,
623
620
  logit_cap,
624
621
  )
625
- _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
622
+ _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
626
623
 
627
624
 
628
625
  def decode_attention_fwd(
@@ -630,9 +627,8 @@ def decode_attention_fwd(
630
627
  k_buffer,
631
628
  v_buffer,
632
629
  o,
633
- req_to_token,
634
- b_req_idx,
635
- b_seq_len,
630
+ kv_indptr,
631
+ kv_indices,
636
632
  attn_logits,
637
633
  num_kv_splits,
638
634
  sm_scale,
@@ -648,9 +644,8 @@ def decode_attention_fwd(
648
644
  k_buffer,
649
645
  v_buffer,
650
646
  o,
651
- req_to_token,
652
- b_req_idx,
653
- b_seq_len,
647
+ kv_indptr,
648
+ kv_indices,
654
649
  attn_logits,
655
650
  num_kv_splits,
656
651
  sm_scale,
@@ -663,9 +658,8 @@ def decode_attention_fwd(
663
658
  k_buffer,
664
659
  v_buffer,
665
660
  o,
666
- req_to_token,
667
- b_req_idx,
668
- b_seq_len,
661
+ kv_indptr,
662
+ kv_indices,
669
663
  attn_logits,
670
664
  num_kv_splits,
671
665
  sm_scale,
@@ -29,14 +29,11 @@ if is_cuda_available():
29
29
  rmsnorm,
30
30
  )
31
31
 
32
- from vllm.model_executor.custom_op import CustomOp
33
-
34
- from sglang.srt.layers.custom_op_util import register_custom_op
32
+ from sglang.srt.custom_op import CustomOp
35
33
 
36
34
  logger = logging.getLogger(__name__)
37
35
 
38
36
 
39
- @register_custom_op("sglang_rmsnorm")
40
37
  class RMSNorm(CustomOp):
41
38
  def __init__(
42
39
  self,
@@ -79,7 +76,6 @@ class RMSNorm(CustomOp):
79
76
  return x, residual
80
77
 
81
78
 
82
- @register_custom_op("sglang_gemma_rmsnorm")
83
79
  class GemmaRMSNorm(CustomOp):
84
80
  def __init__(
85
81
  self,
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
4
4
  import torch
5
5
  from torch.nn import Module
6
6
  from vllm import _custom_ops as ops
7
- from vllm.model_executor.custom_op import CustomOp
8
7
 
8
+ from sglang.srt.custom_op import CustomOp
9
9
  from sglang.srt.distributed import (
10
10
  get_tensor_model_parallel_rank,
11
11
  get_tensor_model_parallel_world_size,
12
12
  )
13
- from sglang.srt.layers.custom_op_util import register_custom_op
14
13
  from sglang.srt.layers.moe.ep_moe.kernels import (
15
14
  grouped_gemm_triton,
16
15
  post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ class EPMoE(torch.nn.Module):
407
406
  param_data[expert_id] = loaded_weight
408
407
 
409
408
 
410
- @register_custom_op("sglang_unquantized_ep_moe")
411
409
  class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
412
410
  def create_weights(
413
411
  self,
@@ -0,0 +1,164 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 32,
4
+ "BLOCK_SIZE_N": 32,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 16,
7
+ "num_warps": 4,
8
+ "num_stages": 2,
9
+ "waves_per_eu": 0
10
+ },
11
+ "2": {
12
+ "BLOCK_SIZE_M": 32,
13
+ "BLOCK_SIZE_N": 64,
14
+ "BLOCK_SIZE_K": 128,
15
+ "GROUP_SIZE_M": 1,
16
+ "num_warps": 4,
17
+ "num_stages": 2,
18
+ "waves_per_eu": 0
19
+ },
20
+ "4": {
21
+ "BLOCK_SIZE_M": 64,
22
+ "BLOCK_SIZE_N": 64,
23
+ "BLOCK_SIZE_K": 128,
24
+ "GROUP_SIZE_M": 16,
25
+ "num_warps": 4,
26
+ "num_stages": 2,
27
+ "waves_per_eu": 0
28
+ },
29
+ "8": {
30
+ "BLOCK_SIZE_M": 32,
31
+ "BLOCK_SIZE_N": 128,
32
+ "BLOCK_SIZE_K": 128,
33
+ "GROUP_SIZE_M": 32,
34
+ "num_warps": 4,
35
+ "num_stages": 2,
36
+ "waves_per_eu": 0
37
+ },
38
+ "16": {
39
+ "BLOCK_SIZE_M": 32,
40
+ "BLOCK_SIZE_N": 128,
41
+ "BLOCK_SIZE_K": 128,
42
+ "GROUP_SIZE_M": 1,
43
+ "num_warps": 4,
44
+ "num_stages": 2,
45
+ "waves_per_eu": 0
46
+ },
47
+ "24": {
48
+ "BLOCK_SIZE_M": 32,
49
+ "BLOCK_SIZE_N": 128,
50
+ "BLOCK_SIZE_K": 128,
51
+ "GROUP_SIZE_M": 4,
52
+ "num_warps": 4,
53
+ "num_stages": 2,
54
+ "waves_per_eu": 0
55
+ },
56
+ "32": {
57
+ "BLOCK_SIZE_M": 32,
58
+ "BLOCK_SIZE_N": 128,
59
+ "BLOCK_SIZE_K": 128,
60
+ "GROUP_SIZE_M": 8,
61
+ "num_warps": 4,
62
+ "num_stages": 2,
63
+ "waves_per_eu": 0
64
+ },
65
+ "48": {
66
+ "BLOCK_SIZE_M": 32,
67
+ "BLOCK_SIZE_N": 128,
68
+ "BLOCK_SIZE_K": 128,
69
+ "GROUP_SIZE_M": 4,
70
+ "num_warps": 4,
71
+ "num_stages": 2,
72
+ "waves_per_eu": 0
73
+ },
74
+ "64": {
75
+ "BLOCK_SIZE_M": 256,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 128,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 2,
81
+ "waves_per_eu": 0
82
+ },
83
+ "96": {
84
+ "BLOCK_SIZE_M": 32,
85
+ "BLOCK_SIZE_N": 128,
86
+ "BLOCK_SIZE_K": 128,
87
+ "GROUP_SIZE_M": 8,
88
+ "num_warps": 4,
89
+ "num_stages": 2,
90
+ "waves_per_eu": 0
91
+ },
92
+ "128": {
93
+ "BLOCK_SIZE_M": 32,
94
+ "BLOCK_SIZE_N": 16,
95
+ "BLOCK_SIZE_K": 128,
96
+ "GROUP_SIZE_M": 4,
97
+ "num_warps": 4,
98
+ "num_stages": 2,
99
+ "waves_per_eu": 0
100
+ },
101
+ "256": {
102
+ "BLOCK_SIZE_M": 64,
103
+ "BLOCK_SIZE_N": 16,
104
+ "BLOCK_SIZE_K": 128,
105
+ "GROUP_SIZE_M": 1,
106
+ "num_warps": 4,
107
+ "num_stages": 2,
108
+ "waves_per_eu": 0
109
+ },
110
+ "512": {
111
+ "BLOCK_SIZE_M": 64,
112
+ "BLOCK_SIZE_N": 64,
113
+ "BLOCK_SIZE_K": 128,
114
+ "GROUP_SIZE_M": 32,
115
+ "num_warps": 4,
116
+ "num_stages": 2,
117
+ "waves_per_eu": 0
118
+ },
119
+ "1024": {
120
+ "BLOCK_SIZE_M": 64,
121
+ "BLOCK_SIZE_N": 64,
122
+ "BLOCK_SIZE_K": 128,
123
+ "GROUP_SIZE_M": 4,
124
+ "num_warps": 8,
125
+ "num_stages": 2,
126
+ "waves_per_eu": 0
127
+ },
128
+ "1536": {
129
+ "BLOCK_SIZE_M": 64,
130
+ "BLOCK_SIZE_N": 64,
131
+ "BLOCK_SIZE_K": 128,
132
+ "GROUP_SIZE_M": 8,
133
+ "num_warps": 4,
134
+ "num_stages": 2,
135
+ "waves_per_eu": 0
136
+ },
137
+ "2048": {
138
+ "BLOCK_SIZE_M": 32,
139
+ "BLOCK_SIZE_N": 64,
140
+ "BLOCK_SIZE_K": 128,
141
+ "GROUP_SIZE_M": 1,
142
+ "num_warps": 4,
143
+ "num_stages": 2,
144
+ "waves_per_eu": 0
145
+ },
146
+ "3072": {
147
+ "BLOCK_SIZE_M": 32,
148
+ "BLOCK_SIZE_N": 128,
149
+ "BLOCK_SIZE_K": 128,
150
+ "GROUP_SIZE_M": 1,
151
+ "num_warps": 4,
152
+ "num_stages": 2,
153
+ "waves_per_eu": 0
154
+ },
155
+ "4096": {
156
+ "BLOCK_SIZE_M": 64,
157
+ "BLOCK_SIZE_N": 128,
158
+ "BLOCK_SIZE_K": 64,
159
+ "GROUP_SIZE_M": 4,
160
+ "num_warps": 4,
161
+ "num_stages": 2,
162
+ "waves_per_eu": 0
163
+ }
164
+ }
@@ -0,0 +1,146 @@
1
+ {
2
+ "1": {
3
+ "BLOCK_SIZE_M": 16,
4
+ "BLOCK_SIZE_N": 64,
5
+ "BLOCK_SIZE_K": 128,
6
+ "GROUP_SIZE_M": 32,
7
+ "num_warps": 4,
8
+ "num_stages": 4
9
+ },
10
+ "2": {
11
+ "BLOCK_SIZE_M": 16,
12
+ "BLOCK_SIZE_N": 128,
13
+ "BLOCK_SIZE_K": 128,
14
+ "GROUP_SIZE_M": 16,
15
+ "num_warps": 4,
16
+ "num_stages": 4
17
+ },
18
+ "4": {
19
+ "BLOCK_SIZE_M": 16,
20
+ "BLOCK_SIZE_N": 256,
21
+ "BLOCK_SIZE_K": 128,
22
+ "GROUP_SIZE_M": 64,
23
+ "num_warps": 8,
24
+ "num_stages": 4
25
+ },
26
+ "8": {
27
+ "BLOCK_SIZE_M": 16,
28
+ "BLOCK_SIZE_N": 256,
29
+ "BLOCK_SIZE_K": 128,
30
+ "GROUP_SIZE_M": 32,
31
+ "num_warps": 8,
32
+ "num_stages": 4
33
+ },
34
+ "16": {
35
+ "BLOCK_SIZE_M": 16,
36
+ "BLOCK_SIZE_N": 128,
37
+ "BLOCK_SIZE_K": 128,
38
+ "GROUP_SIZE_M": 64,
39
+ "num_warps": 4,
40
+ "num_stages": 5
41
+ },
42
+ "24": {
43
+ "BLOCK_SIZE_M": 16,
44
+ "BLOCK_SIZE_N": 128,
45
+ "BLOCK_SIZE_K": 128,
46
+ "GROUP_SIZE_M": 64,
47
+ "num_warps": 4,
48
+ "num_stages": 3
49
+ },
50
+ "32": {
51
+ "BLOCK_SIZE_M": 16,
52
+ "BLOCK_SIZE_N": 256,
53
+ "BLOCK_SIZE_K": 128,
54
+ "GROUP_SIZE_M": 1,
55
+ "num_warps": 4,
56
+ "num_stages": 3
57
+ },
58
+ "48": {
59
+ "BLOCK_SIZE_M": 16,
60
+ "BLOCK_SIZE_N": 256,
61
+ "BLOCK_SIZE_K": 128,
62
+ "GROUP_SIZE_M": 1,
63
+ "num_warps": 4,
64
+ "num_stages": 3
65
+ },
66
+ "64": {
67
+ "BLOCK_SIZE_M": 16,
68
+ "BLOCK_SIZE_N": 256,
69
+ "BLOCK_SIZE_K": 128,
70
+ "GROUP_SIZE_M": 32,
71
+ "num_warps": 4,
72
+ "num_stages": 3
73
+ },
74
+ "96": {
75
+ "BLOCK_SIZE_M": 16,
76
+ "BLOCK_SIZE_N": 128,
77
+ "BLOCK_SIZE_K": 64,
78
+ "GROUP_SIZE_M": 1,
79
+ "num_warps": 4,
80
+ "num_stages": 4
81
+ },
82
+ "128": {
83
+ "BLOCK_SIZE_M": 16,
84
+ "BLOCK_SIZE_N": 128,
85
+ "BLOCK_SIZE_K": 64,
86
+ "GROUP_SIZE_M": 1,
87
+ "num_warps": 4,
88
+ "num_stages": 4
89
+ },
90
+ "256": {
91
+ "BLOCK_SIZE_M": 16,
92
+ "BLOCK_SIZE_N": 128,
93
+ "BLOCK_SIZE_K": 64,
94
+ "GROUP_SIZE_M": 1,
95
+ "num_warps": 4,
96
+ "num_stages": 4
97
+ },
98
+ "512": {
99
+ "BLOCK_SIZE_M": 16,
100
+ "BLOCK_SIZE_N": 256,
101
+ "BLOCK_SIZE_K": 128,
102
+ "GROUP_SIZE_M": 32,
103
+ "num_warps": 4,
104
+ "num_stages": 3
105
+ },
106
+ "1024": {
107
+ "BLOCK_SIZE_M": 16,
108
+ "BLOCK_SIZE_N": 256,
109
+ "BLOCK_SIZE_K": 128,
110
+ "GROUP_SIZE_M": 1,
111
+ "num_warps": 4,
112
+ "num_stages": 3
113
+ },
114
+ "1536": {
115
+ "BLOCK_SIZE_M": 16,
116
+ "BLOCK_SIZE_N": 256,
117
+ "BLOCK_SIZE_K": 128,
118
+ "GROUP_SIZE_M": 1,
119
+ "num_warps": 4,
120
+ "num_stages": 3
121
+ },
122
+ "2048": {
123
+ "BLOCK_SIZE_M": 16,
124
+ "BLOCK_SIZE_N": 256,
125
+ "BLOCK_SIZE_K": 128,
126
+ "GROUP_SIZE_M": 1,
127
+ "num_warps": 4,
128
+ "num_stages": 3
129
+ },
130
+ "3072": {
131
+ "BLOCK_SIZE_M": 32,
132
+ "BLOCK_SIZE_N": 256,
133
+ "BLOCK_SIZE_K": 128,
134
+ "GROUP_SIZE_M": 16,
135
+ "num_warps": 4,
136
+ "num_stages": 3
137
+ },
138
+ "4096": {
139
+ "BLOCK_SIZE_M": 32,
140
+ "BLOCK_SIZE_N": 256,
141
+ "BLOCK_SIZE_K": 128,
142
+ "GROUP_SIZE_M": 16,
143
+ "num_warps": 4,
144
+ "num_stages": 3
145
+ }
146
+ }