sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
@@ -68,9 +68,6 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  self.num_q_heads = (
  model_runner.model_config.num_attention_heads // get_attention_tp_size()
  )
- self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
- get_attention_tp_size()
- )
  self.req_to_token = model_runner.req_to_token_pool.req_to_token
  self.num_local_heads = (
  model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -111,8 +108,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  )
  mla_metadata, num_splits = get_mla_metadata(
  forward_batch.seq_lens.to(torch.int32),
- Q_LEN * self.num_q_heads // self.num_kv_heads,
- self.num_kv_heads,
+ Q_LEN * self.num_q_heads,
+ 1,
  )
  self.forward_metadata = FlashMLADecodeMetadata(
  mla_metadata,
@@ -141,8 +138,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):

  self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
  torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
- Q_LEN * self.num_q_heads // self.num_kv_heads,
- self.num_kv_heads,
+ Q_LEN * self.num_q_heads,
+ 1,
  )
  self.cuda_graph_kv_indices = cuda_graph_kv_indices

@@ -171,8 +168,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  )
  mla_metadata, num_splits = get_mla_metadata(
  seq_lens.to(torch.int32),
- Q_LEN * self.num_q_heads // self.num_kv_heads,
- self.num_kv_heads,
+ Q_LEN * self.num_q_heads,
+ 1,
  )
  self.cuda_graph_mla_metadata.copy_(mla_metadata)
  self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
@@ -221,8 +218,8 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
  )
  mla_metadata, num_splits = get_mla_metadata(
  seq_lens.to(torch.int32),
- Q_LEN * self.num_q_heads // self.num_kv_heads,
- self.num_kv_heads,
+ Q_LEN * self.num_q_heads,
+ 1,
  )
  self.cuda_graph_mla_metadata.copy_(mla_metadata)
  self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
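The FlashMLA hunks above drop the separate `num_kv_heads` bookkeeping and pass a constant `1` as the KV head count to `get_mla_metadata`, reflecting that MLA decoding attends over a single shared latent KV head. A minimal sanity check (hypothetical head counts, illustration only) showing that the old and new argument forms agree once the KV head count is 1:

```python
# Illustration only: hypothetical head counts, not values read from a model config.
Q_LEN, num_q_heads, num_kv_heads = 1, 16, 1

# The old call passed (Q_LEN * num_q_heads // num_kv_heads, num_kv_heads);
# the new call passes (Q_LEN * num_q_heads, 1). With a single latent KV head
# the two are identical, so only the constant form is kept.
assert Q_LEN * num_q_heads // num_kv_heads == Q_LEN * num_q_heads
assert num_kv_heads == 1
```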
@@ -3,10 +3,10 @@ import triton
  import triton.language as tl

  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.utils import is_hip
+ from sglang.srt.utils import is_cuda, is_hip

- is_cuda_available = torch.cuda.is_available()
- if is_cuda_available:
+ _is_cuda = is_cuda()
+ if _is_cuda:
  CUDA_CAPABILITY = torch.cuda.get_device_capability()

  _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
  num_warps = 4

  else:
- if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+ if _is_cuda and CUDA_CAPABILITY[0] >= 9:
  if Lq <= 256:
  BLOCK_M, BLOCK_N = (128, 64)
  else:
  BLOCK_M, BLOCK_N = (32, 64)
- elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+ elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
  if Lq <= 128:
  BLOCK_M, BLOCK_N = (128, 128)
  elif Lq <= 256:
@@ -23,10 +23,10 @@ import triton.language as tl
  from sglang.srt.layers.attention.triton_ops.prefill_attention import (
  context_attention_fwd,
  )
- from sglang.srt.utils import is_hip
+ from sglang.srt.utils import is_cuda, is_hip

- is_cuda_available = torch.cuda.is_available()
- if is_cuda_available:
+ _is_cuda = is_cuda()
+ if _is_cuda:
  CUDA_CAPABILITY = torch.cuda.get_device_capability()

  _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
  num_warps = 4

  else:
- if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+ if _is_cuda and CUDA_CAPABILITY[0] >= 9:
  if Lq <= 256:
  BLOCK_M, BLOCK_N = (128, 64)
  else:
  BLOCK_M, BLOCK_N = (32, 64)
- elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+ elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
  # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
  if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
  if Lq <= 128:
@@ -22,8 +22,12 @@ import torch
  import triton
  import triton.language as tl

- is_cuda_available = torch.cuda.is_available()
- if is_cuda_available:
+ from sglang.srt.utils import is_cuda, is_hip
+
+ _is_cuda = is_cuda()
+ _is_hip = is_hip()
+
+ if _is_cuda or _is_hip:
  CUDA_CAPABILITY = torch.cuda.get_device_capability()


@@ -172,7 +176,7 @@ def context_attention_fwd(
  b_seq_len: [b]
  out: [b * s, head, head_dim]
  """
- if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+ if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
  BLOCK = 128
  else:
  BLOCK = 64
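These attention-kernel hunks replace ad-hoc `torch.cuda.is_available()` checks with the shared `is_cuda()` / `is_hip()` helpers, so ROCm builds (where `torch.cuda.is_available()` also reports True) are no longer treated as CUDA. A minimal sketch of how such helpers can distinguish the two platforms; this is an assumption for illustration, not sglang's actual `sglang.srt.utils` implementation:

```python
import torch

def is_cuda() -> bool:
    # True only for NVIDIA CUDA builds of PyTorch; ROCm builds set torch.version.hip.
    return torch.cuda.is_available() and torch.version.hip is None

def is_hip() -> bool:
    # True for ROCm/HIP builds, where torch.cuda.is_available() is also True.
    return torch.version.hip is not None
```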
@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
  Returns:
  [b * s, h, head_size]
  """
+ if self.flatten_batch:
+ assert bsz == 1, "flatten_batch is True, bsz must be 1"

  s = q.shape[0] // bsz

@@ -143,7 +143,7 @@ def memcpy_triton_kernel(
  src_ptr,
  offset_ptr,
  sz_ptr,
- offset_src,
+ offset_src: tl.constexpr,
  chunk_size, # multiplied for offset and sz
  BLOCK_SIZE: tl.constexpr,
  ):
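Marking `offset_src` as `tl.constexpr` lets Triton specialize `memcpy_triton_kernel` at compile time instead of branching on a runtime scalar; each constexpr value produces its own compiled variant. A small, self-contained sketch of the same pattern (hypothetical kernel, not the memcpy kernel above):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(dst_ptr, src_ptr, n_elements, NEGATE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(src_ptr + offs, mask=mask)
    if NEGATE:  # resolved when the kernel is compiled, so the specialized variant has no runtime branch
        x = -x
    tl.store(dst_ptr + offs, x, mask=mask)

src = torch.randn(1024, device="cuda")
dst = torch.empty_like(src)
scale_kernel[(triton.cdiv(src.numel(), 256),)](dst, src, src.numel(), NEGATE=True, BLOCK_SIZE=256)
```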
@@ -20,9 +20,10 @@ import torch
  import torch.nn as nn

  from sglang.srt.custom_op import CustomOp
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import is_cuda, is_hip

- _is_cuda = is_cuda_available()
+ _is_cuda = is_cuda()
+ _is_hip = is_hip()

  if _is_cuda:
  from sgl_kernel import (
@@ -32,6 +33,8 @@ if _is_cuda:
  rmsnorm,
  )

+ if _is_hip:
+ from vllm._custom_ops import fused_add_rms_norm, rms_norm

  logger = logging.getLogger(__name__)

@@ -46,23 +49,49 @@ class RMSNorm(CustomOp):
  self.weight = nn.Parameter(torch.ones(hidden_size))
  self.variance_epsilon = eps

+ def forward(self, *args, **kwargs):
+ if torch.compiler.is_compiling():
+ return self.forward_native(*args, **kwargs)
+ if _is_cuda:
+ return self.forward_cuda(*args, **kwargs)
+ elif _is_hip:
+ return self.forward_hip(*args, **kwargs)
+ else:
+ return self.forward_native(*args, **kwargs)
+
  def forward_cuda(
  self,
  x: torch.Tensor,
  residual: Optional[torch.Tensor] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-
  if residual is not None:
  fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
  return x, residual
  out = rmsnorm(x, self.weight.data, self.variance_epsilon)
  return out

+ def forward_hip(
+ self,
+ x: torch.Tensor,
+ residual: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ if not x.is_contiguous():
+ # NOTE: Romove this if aiter kernel supports discontinuous input
+ x = x.contiguous()
+ if residual is not None:
+ fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+ return x, residual
+ out = torch.empty_like(x)
+ rms_norm(out, x, self.weight.data, self.variance_epsilon)
+ return out
+
  def forward_native(
  self,
  x: torch.Tensor,
  residual: Optional[torch.Tensor] = None,
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ if not x.is_contiguous():
+ x = x.contiguous()
  orig_dtype = x.dtype
  x = x.to(torch.float32)
  if residual is not None:
@@ -88,6 +117,14 @@ class GemmaRMSNorm(CustomOp):
  self.weight = nn.Parameter(torch.zeros(hidden_size))
  self.variance_epsilon = eps

+ def forward(self, *args, **kwargs):
+ if torch.compiler.is_compiling():
+ return self.forward_native(*args, **kwargs)
+ if _is_cuda:
+ return self.forward_cuda(*args, **kwargs)
+ else:
+ return self.forward_native(*args, **kwargs)
+
  def forward_native(
  self,
  x: torch.Tensor,
@@ -139,8 +176,8 @@ class Gemma3RMSNorm(nn.Module):
  return f"{tuple(self.weight.shape)}, eps={self.eps}"


- if not _is_cuda:
+ if not (_is_cuda or _is_hip):
  logger.info(
- "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
+ "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
  )
  from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
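The `forward` methods added to `RMSNorm` and `GemmaRMSNorm` above route each call to a pure-PyTorch path while `torch.compile` is tracing, and to a platform kernel otherwise (sgl-kernel on CUDA, vLLM's HIP ops on ROCm). A minimal self-contained sketch of the same dispatch pattern, with only the native math filled in; the platform branch here is a placeholder, not sgl-kernel's real entry point:

```python
import torch
import torch.nn as nn

class TinyRMSNorm(nn.Module):
    """Sketch of the dispatch pattern above; only the native path is implemented here."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            # Let torch.compile trace plain PyTorch ops instead of opaque custom kernels.
            return self.forward_native(x)
        # A real implementation would select forward_cuda / forward_hip here.
        return self.forward_native(x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        return (x * self.weight).to(orig_dtype)

norm = TinyRMSNorm(8)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```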
@@ -335,13 +335,13 @@ class LogitsProcessor(nn.Module):
  aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
  hidden_states_to_store = (
  aux_pruned_states[sample_indices]
- if sample_indices
+ if sample_indices is not None
  else aux_pruned_states
  )
  else:
  hidden_states_to_store = (
  pruned_states[sample_indices]
- if sample_indices
+ if sample_indices is not None
  else pruned_states
  )
  else:
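The `if sample_indices` → `if sample_indices is not None` change matters because `sample_indices` is an index tensor here, and Python truthiness on a tensor is not a None check: a multi-element tensor raises, and an empty or zero-valued one is falsy. A small standalone illustration of the pitfall (not the LogitsProcessor code itself):

```python
import torch

sample_indices = torch.tensor([0, 2])   # hypothetical indices into a batch
try:
    if sample_indices:                  # ambiguous for tensors with more than one element
        pass
except RuntimeError as e:
    print("truthiness check fails:", e)

# The explicit None check only asks "were indices provided?", which is the intent.
if sample_indices is not None:
    print("indices provided")
```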
@@ -802,6 +802,7 @@ class DeepEPMoE(EPMoE):
  correction_bias: Optional[torch.Tensor] = None,
  custom_routing_function: Optional[Callable] = None,
  activation: str = "silu",
+ routed_scaling_factor: Optional[float] = None,
  deepep_mode: DeepEPMode = DeepEPMode.auto,
  ):
  super().__init__(
@@ -820,6 +821,7 @@ class DeepEPMoE(EPMoE):
  correction_bias,
  custom_routing_function,
  activation,
+ routed_scaling_factor,
  )
  self.deepep_mode = deepep_mode
  if self.deepep_mode.enable_low_latency():
@@ -8,6 +8,7 @@ from typing import Callable, Optional
  import torch
  from torch.nn import functional as F

+ from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
  from sglang.srt.layers.moe.topk import select_experts


@@ -30,7 +31,7 @@ def fused_moe_forward_native(
  ) -> torch.Tensor:

  if apply_router_weight_on_input:
- raise NotImplementedError
+ raise NotImplementedError()

  topk_weights, topk_ids = select_experts(
  hidden_states=x,
@@ -75,9 +76,6 @@ def moe_forward_native(
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
-
- from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-
  topk_weights, topk_ids = select_experts(
  hidden_states=x,
  router_logits=router_logits,
@@ -1,102 +1,102 @@
  {
  "1": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 64,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
  "num_stages": 4
  },
  "2": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "4": {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 1,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
  "num_stages": 4
  },
  "8": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "16": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
  "num_stages": 3
  },
  "24": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 1,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "32": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 5
  },
  "48": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "64": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
  "num_stages": 3
  },
  "96": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
  "num_stages": 3
  },
  "128": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
  "num_stages": 3
  },
  "256": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 64,
  "num_warps": 4,
  "num_stages": 3
  },
  "512": {
- "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_M": 16,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 16,
@@ -107,9 +107,9 @@
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "1536": {
  "BLOCK_SIZE_M": 64,
@@ -117,21 +117,21 @@
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "2048": {
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 16,
+ "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  },
  "3072": {
- "BLOCK_SIZE_M": 128,
- "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 32,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
  "num_stages": 3
  },
@@ -139,8 +139,8 @@
  "BLOCK_SIZE_M": 64,
  "BLOCK_SIZE_N": 128,
  "BLOCK_SIZE_K": 128,
- "GROUP_SIZE_M": 64,
+ "GROUP_SIZE_M": 16,
  "num_warps": 4,
- "num_stages": 3
+ "num_stages": 4
  }
  }
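The JSON files above are per-batch-size Triton tuning tables for the fused MoE kernel: each top-level key is a token count (M), and the entry holds the launch parameters (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages) tuned for that size on the named GPU. A rough sketch of how such a table can be consumed; the nearest-key lookup is an assumption for illustration, not fused_moe.py's exact selection logic:

```python
import json

def load_moe_config(path: str) -> dict[int, dict]:
    # Keys in the file are strings ("1", "2", ...); convert them to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Assumed heuristic: use the tuned entry whose batch size is closest to num_tokens.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Hypothetical usage with the H200 table shown in this diff:
# configs = load_moe_config("E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json")
# print(pick_config(configs, num_tokens=20))
```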
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 256,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ }
+ }