sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py

@@ -35,21 +35,6 @@ class Sampler(CustomOp):
         self.forward_native = self.forward_cuda
         self.is_torch_compile = False
 
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
     def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
         # FIXME: Temporary workaround for unknown bugs in torch.compile
         logits.add_(0)
 
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)
 
     def forward_cuda(
@@ -78,7 +55,7 @@ class Sampler(CustomOp):
 
         probs = self._get_probs(logits, sampling_info)
 
-        if not global_server_args_dict["disable_flashinfer_sampling"]:
+        if global_server_args_dict["sampling_backend"] == "flashinfer":
            max_top_k_round, batch_size = 32, probs.shape[0]
            uniform_samples = torch.rand(
                (max_top_k_round, batch_size), device=probs.device
@@ -93,11 +70,15 @@ class Sampler(CustomOp):
             batch_next_token_ids, success = flashinfer_top_k_top_p(
                 probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
             )
-        else:
+        elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
             batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
+        else:
+            raise ValueError(
+                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+            )
 
         return SampleOutput(success, probs, batch_next_token_ids)
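The forward path now dispatches on a `sampling_backend` string instead of the old `disable_flashinfer_sampling` flag, and rejects unknown values. For orientation, the sketch below shows the kind of top-k / top-p / min-p filtering a "slower fallback implementation" performs on a probability tensor. It is a minimal illustration only, not sglang's actual `top_k_top_p_min_p_sampling_from_probs_torch`, whose details may differ.

```python
import torch


def top_k_top_p_min_p_probs_sketch(probs, top_ks, top_ps, min_ps):
    # Sort each row's probabilities in descending order.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum_probs = sorted_probs.cumsum(dim=-1)

    # top-p: drop tokens once the cumulative mass *before* them exceeds top_p.
    mask = (cum_probs - sorted_probs) > top_ps.unsqueeze(-1)
    # top-k: drop everything ranked at or beyond the k-th token.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(sorted_probs)
    mask |= ranks >= top_ks.unsqueeze(-1)
    # min-p: drop tokens whose probability is below min_p * (row max).
    mask |= sorted_probs < min_ps.unsqueeze(-1) * sorted_probs[:, :1]

    filtered = sorted_probs.masked_fill(mask, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)

    # Sample in the sorted space, then map back to vocabulary ids.
    sampled = torch.multinomial(filtered, num_samples=1)
    return sorted_idx.gather(-1, sampled).squeeze(-1)


# Example: batch of 2, vocabulary of 8.
probs = torch.softmax(torch.randn(2, 8), dim=-1)
token_ids = top_k_top_p_min_p_probs_sketch(
    probs,
    top_ks=torch.tensor([5, 3]),
    top_ps=torch.tensor([0.9, 0.8]),
    min_ps=torch.tensor([0.0, 0.05]),
)
print(token_ids)
```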
 
sglang/srt/layers/torchao_utils.py (new file)

@@ -0,0 +1,75 @@
+"""
+Common utilities for torchao.
+"""
+
+from typing import Dict, Set
+
+import torch
+
+
+def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
+    """Quantize a Tensor with torchao quantization specified by torchao_config
+
+    Args:
+        `param`: weight parameter of the linear module
+        `torchao_config`: type of quantization and their arguments we want to use to
+            quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization
+            with group_size 128
+    """
+    # Lazy import to suppress some warnings
+    from torchao.quantization import (
+        int4_weight_only,
+        int8_dynamic_activation_int8_weight,
+        int8_weight_only,
+        quantize_,
+    )
+
+    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear.weight = param
+    if "int8wo" in torchao_config:
+        quantize_(dummy_linear, int8_weight_only())
+    elif "int8dq" in torchao_config:
+        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+    elif "int4wo" in torchao_config:
+        group_size = int(torchao_config.split("-")[-1])
+        assert group_size in [
+            32,
+            64,
+            128,
+            256,
+        ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
+        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+    elif "fp8wo" in torchao_config:
+        from torchao.quantization import float8_weight_only
+
+        # this requires newer hardware
+        # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
+        quantize_(dummy_linear, float8_weight_only())
+    return dummy_linear.weight
+
+
+def apply_torchao_config_(
+    self: torch.nn.Module,
+    params_dict: Dict[str, torch.Tensor],
+    param_suffixes: Set[str],
+) -> None:
+    """A util function used for quantizing the weight parameters after they are loaded
+    if self.torchao_config is specified
+
+    Args:
+        `self`: the model we want to quantize
+        `params_dict`: dictionary mapping from param_name to the parameter Tensor
+        `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
+
+    Returns:
+        None, the `params_dict` is modified inplace and the weights of the `self` model are quantized
+    """
+    if self.torchao_config:
+        for param_suffix in param_suffixes:
+            for name in params_dict:
+                param = params_dict[name]
+                if param_suffix in name and param.ndim == 2:
+                    params_dict[name] = torchao_quantize_param_data(
+                        param, self.torchao_config
+                    )
+        self.load_state_dict(params_dict, assign=True)
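The new module quantizes a 2-D weight by wrapping it in a dummy `nn.Linear` and returning the quantized weight tensor. A hedged usage sketch of `torchao_quantize_param_data` follows; the shapes, the bf16/CUDA choice, and the divisibility note are illustrative assumptions (int4 weight-only quantization in torchao generally expects a CUDA bf16 weight whose input dimension is divisible by the group size), and you need torchao installed.

```python
import torch

from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

# Hypothetical weight; the function expects the parameter of a linear layer,
# i.e. shape (out_features, in_features), and assigns it to a dummy nn.Linear,
# so it must be an nn.Parameter.
weight = torch.nn.Parameter(
    torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
    requires_grad=False,
)

# "int4wo-128" = int4 weight-only quantization with group_size 128,
# following the config-string format documented in the module above.
quantized = torchao_quantize_param_data(weight, "int4wo-128")
print(type(quantized), quantized.shape)
```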
sglang/srt/layers/triton_attention/decode_attention.py (renamed from sglang/srt/layers/decode_attention.py)

@@ -15,24 +15,15 @@ limitations under the License.
 
 """
 Memory-efficient attention for decoding.
+It supports page size = 1.
 """
 
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
-import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
-if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
-    REDUCE_TRITON_TYPE = tl.float32
-    REDUCE_TORCH_TYPE = torch.float32
-else:
-    REDUCE_TRITON_TYPE = tl.float16
-    REDUCE_TORCH_TYPE = torch.float16
-
 
 @triton.jit
 def tanh(x):
@@ -60,11 +51,13 @@ def _fwd_kernel_stage1(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
     start_n = tl.program_id(2)
 
+    reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
@@ -83,7 +76,7 @@ def _fwd_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + off_q + start_mark).to(REDUCE_TRITON_TYPE)
+        q = tl.load(Q + off_q + start_mark).to(reduce_dtype)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -97,9 +90,9 @@ def _fwd_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=offs_n_new[:, None] < cur_batch_end_index,
+            mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
             other=0.0,
-        ).to(REDUCE_TRITON_TYPE)
+        ).to(reduce_dtype)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
 
@@ -112,7 +105,7 @@
 
 @triton.jit
 def _fwd_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -128,6 +121,7 @@ def _fwd_kernel_stage2(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -159,7 +153,7 @@ def _fwd_kernel_stage2(
         )
 
         qk = tl.load(
-            Logics
+            logits
             + cur_head * stride_logic_h
            + (cur_batch_start_loc + start_n + offs_n),
            mask=start_n + offs_n < cur_batch_seq_len,
@@ -170,14 +164,16 @@ def _fwd_kernel_stage2(
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max)
         e_sum = e_sum * old_scale + tl.sum(p, 0)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
         acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
         e_max = n_e_max
 
     acc = acc / e_sum
     off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc)
+    tl.store(out_ptrs, acc, mask=(offs_d < Lv))
 
 
 def _decode_att_m_fwd(
@@ -193,10 +189,7 @@ def _decode_att_m_fwd(
     logit_cap,
 ):
     BLOCK = 32
-    # shape constraints
-    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
-    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    Lk = k_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -208,6 +201,8 @@ def _decode_att_m_fwd(
     else:
         num_warps = 2
 
+    BLOCK_DMODEL = triton.next_power_of_2(Lk)
+
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
@@ -224,16 +219,17 @@ def _decode_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
 
 
 def _decode_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -242,31 +238,35 @@ def _decode_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 64
-    batch, head = b_seq_len.shape[0], logics.shape[0]
+    batch, head = b_seq_len.shape[0], logits.shape[0]
     grid = (batch, head, 1)
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
 
     num_warps = 1
 
+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=3,
+        Lv=Lv,
     )
 
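Both stages of the non-grouped kernel now derive their reduction dtype from the output buffer (`Att_Out.dtype.element_ty`) instead of a module-level flag, and the launchers pad the head dimension with `triton.next_power_of_2(...)` while masking the padded lanes via `offs_d < Lk` / `offs_d < Lv`, so non-power-of-two head sizes no longer need the old asserts. The standalone toy kernel below (not sglang code, just a minimal sketch) shows that pad-and-mask pattern for a head dimension of 80.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def scale_rows_kernel(X, Y, scale, stride_x, stride_y,
                      Ld: tl.constexpr, BLOCK_D: tl.constexpr):
    row = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_D)   # padded lanes: BLOCK_D is a power of two
    mask = offs_d < Ld               # keep only the real head-dim lanes
    x = tl.load(X + row * stride_x + offs_d, mask=mask, other=0.0)
    tl.store(Y + row * stride_y + offs_d, x * scale, mask=mask)


def scale_rows(x: torch.Tensor, scale: float) -> torch.Tensor:
    rows, head_dim = x.shape                      # e.g. head_dim = 80
    y = torch.empty_like(x)
    BLOCK_D = triton.next_power_of_2(head_dim)    # 80 -> 128
    scale_rows_kernel[(rows,)](
        x, y, scale, x.stride(0), y.stride(0), Ld=head_dim, BLOCK_D=BLOCK_D
    )
    return y


x = torch.randn(4, 80, device="cuda", dtype=torch.float16)
assert torch.allclose(scale_rows(x, 2.0), x * 2.0)
```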
 
@@ -293,11 +293,13 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
     logit_cap: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
     start_n = tl.program_id(2)
 
+    reduce_dtype = Att_Out.dtype.element_ty
     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
     mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
     mask_h = mask_h & (cur_head < q_head_num)
@@ -324,9 +326,9 @@ def _fwd_grouped_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
-            REDUCE_TRITON_TYPE
-        )
+        q = tl.load(
+            Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
+        ).to(reduce_dtype)
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -340,13 +342,13 @@ def _fwd_grouped_kernel_stage1(
         )
         k = tl.load(
             K_Buffer + offs_buf_k,
-            mask=offs_n_new[None, :] < cur_batch_end_index,
+            mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
             other=0.0,
-        ).to(REDUCE_TRITON_TYPE)
+        ).to(reduce_dtype)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
             qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
-                REDUCE_TRITON_TYPE
+                reduce_dtype
             )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
@@ -357,7 +359,7 @@ def _fwd_grouped_kernel_stage1(
                 K_Buffer + offs_buf_kpe,
                 mask=offs_n_new[None, :] < cur_batch_end_index,
                 other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
+            ).to(reduce_dtype)
             qk += tl.dot(qpe, kpe)
         qk *= sm_scale
 
@@ -377,7 +379,7 @@
 
 @triton.jit
 def _fwd_grouped_kernel_stage2(
-    Logics,
+    logits,
     V_Buffer,
     Out,
     Req_to_tokens,
@@ -395,6 +397,7 @@ def _fwd_grouped_kernel_stage2(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
@@ -432,7 +435,7 @@ def _fwd_grouped_kernel_stage2(
         )
 
         qk = tl.load(
-            Logics + offs_qk,
+            logits + offs_qk,
             mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
             other=float("-inf"),
         )
@@ -441,7 +444,9 @@ def _fwd_grouped_kernel_stage2(
         old_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max[:, None])
         e_sum = e_sum * old_scale + tl.sum(p, 1)
-        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        v = tl.load(
+            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        )
         p = p.to(v.dtype)
         acc = acc * old_scale[:, None] + tl.dot(p, v)
         e_max = n_e_max
@@ -449,7 +454,7 @@ def _fwd_grouped_kernel_stage2(
     acc = acc / e_sum[:, None]
     off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
     out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
 
 
 def _decode_grouped_att_m_fwd(
@@ -464,17 +469,17 @@ def _decode_grouped_att_m_fwd(
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 32
-    # shape constraints
-    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
-    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
+    BLOCK = 64
+    Lk = k_buffer.shape[-1]
 
     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
+    elif Lk == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
-        BLOCK_DMODEL = Lk
+        BLOCK_DMODEL = triton.next_power_of_2(Lk)
         BLOCK_DPE = 0
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
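The grouped launcher now special-cases head sizes 576 and 288, splitting them into a main tile (`BLOCK_DMODEL`) plus a trailing tile (`BLOCK_DPE`), and pads everything else to the next power of two with `BLOCK_DPE = 0`. The 512+64 and 256+32 splits are taken directly from the hunk above; reading the tail as the rotary/positional-encoding portion of the head is an interpretation, not something the diff states. A tiny sketch of that selection logic:

```python
import triton


def pick_head_dim_blocks(Lk: int):
    """Mirror of the launcher's block selection (sketch, not sglang code).

    Returns (BLOCK_DMODEL, BLOCK_DPE): the main head-dim tile and the extra
    tile for the trailing dims, if any.
    """
    if Lk == 576:        # split as 512 + 64
        return 512, 64
    elif Lk == 288:      # split as 256 + 32
        return 256, 32
    else:                # pad to a power of two; padded lanes are masked off
        return triton.next_power_of_2(Lk), 0


print(pick_head_dim_blocks(576))  # (512, 64)
print(pick_head_dim_blocks(288))  # (256, 32)
print(pick_head_dim_blocks(80))   # (128, 0)
```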
@@ -513,11 +518,12 @@ def _decode_grouped_att_m_fwd(
         logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
 
 
 def _decode_grouped_softmax_reducev_fwd(
-    logics,
+    logits,
     v_buffer,
     o,
     req_to_tokens,
@@ -526,22 +532,25 @@ def _decode_grouped_softmax_reducev_fwd(
     b_seq_len,
 ):
     BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logics.shape[0]
-    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    batch, head_num = b_seq_len.shape[0], logits.shape[0]
+    kv_group_num = logits.shape[0] // v_buffer.shape[1]
     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
 
     num_warps = 8
 
+    Lv = v_buffer.shape[-1]
+    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+
     _fwd_grouped_kernel_stage2[grid](
-        logics,
+        logits,
         v_buffer,
         o,
         req_to_tokens,
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        logics.stride(0),
+        logits.stride(0),
         v_buffer.stride(0),
         v_buffer.stride(1),
         o.stride(0),
@@ -549,9 +558,10 @@ def _decode_grouped_softmax_reducev_fwd(
         req_to_tokens.stride(0),
         kv_group_num=kv_group_num,
         q_head_num=head_num,
-        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
+        Lv=Lv,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -566,17 +576,11 @@ def decode_attention_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
+    attn_logits,
     max_len_in_batch,
-    total_num_tokens,
     sm_scale,
-    logit_cap=-1,
-    att_m=None,
+    logit_cap=0.0,
 ):
-    if att_m is None:
-        att_m = torch.empty(
-            (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
-        )
-
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -584,7 +588,7 @@ def decode_attention_fwd(
         _decode_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -594,7 +598,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
             v_buffer,
             o,
             req_to_token,
@@ -607,7 +611,7 @@ def decode_attention_fwd(
         _decode_grouped_att_m_fwd(
             q,
             k_buffer,
-            att_m,
+            attn_logits,
             req_to_token,
             b_req_idx,
             b_start_loc,
@@ -617,7 +621,7 @@ def decode_attention_fwd(
             logit_cap,
         )
         _decode_grouped_softmax_reducev_fwd(
-            att_m,
+            attn_logits,
            v_buffer,
            o,
            req_to_token,
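`decode_attention_fwd` no longer allocates its intermediate buffer internally: the old `total_num_tokens` / `att_m` arguments are gone, the caller passes a preallocated `attn_logits` tensor, and `logit_cap` now defaults to 0.0. Below is a hedged sketch of the new calling convention; the tensor shapes, the float32 choice, and the order of the leading arguments are illustrative assumptions read from the hunks above, not verified against the full source.

```python
import torch

from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd

batch_size, num_q_heads, num_kv_heads, head_dim = 4, 32, 8, 128
pool_size = 16384  # size of the KV token pool (assumed)

# One query token per request during decoding.
q = torch.randn(batch_size, num_q_heads, head_dim, device="cuda", dtype=torch.float16)
o = torch.empty_like(q)
k_buffer = torch.randn(pool_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
v_buffer = torch.randn(pool_size, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)

b_seq_len = torch.tensor([37, 512, 8, 100], device="cuda", dtype=torch.int32)
b_start_loc = torch.cumsum(b_seq_len, dim=0, dtype=torch.int32) - b_seq_len
b_req_idx = torch.arange(batch_size, device="cuda", dtype=torch.int32)
req_to_token = torch.randint(
    0, pool_size, (batch_size, int(b_seq_len.max())), device="cuda", dtype=torch.int32
)

# The caller now owns the intermediate logits buffer (previously allocated
# inside decode_attention_fwd as (q.shape[-2], total_num_tokens)); its dtype
# also sets the kernels' reduction dtype via Att_Out.dtype.element_ty.
attn_logits = torch.empty(
    num_q_heads, int(b_seq_len.sum()), device="cuda", dtype=torch.float32
)

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    req_to_token, b_req_idx, b_start_loc, b_seq_len,
    attn_logits,
    int(b_seq_len.max()),       # max_len_in_batch
    1.0 / head_dim ** 0.5,      # sm_scale
    logit_cap=0.0,
)
```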