sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
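
Note on items 16-25 and 69-71: these reflect a package reorganization rather than kernel changes. The MoE code moves from sglang/srt/layers/ep_moe and sglang/srt/layers/fused_moe_triton into a consolidated sglang/srt/layers/moe/ package (joined by the new moe/topk.py and moe/fused_moe_native.py), and the old fused_moe_patch.py is deleted. Downstream code importing the old paths must switch to the new prefix. A minimal sketch of the path change (the submodule file names are kept per the renames above; which symbols each __init__.py re-exports is not shown in this diff):

    # sglang 0.4.0.post1 layout: sglang/srt/layers/fused_moe_triton/fused_moe.py
    from sglang.srt.layers.fused_moe_triton import fused_moe

    # sglang 0.4.1 layout: sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
    from sglang.srt.layers.moe.fused_moe_triton import fused_moe
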
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -17,8 +17,11 @@ It supports page size = 1.
 """
 
 # Adapted from
-# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
-# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
+# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
+
+import logging
+
 import triton
 import triton.language as tl
 
@@ -26,6 +29,13 @@ from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
 
+logger = logging.getLogger(__name__)
+
+# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy.
+logger.warning(
+    "The following error message 'operation scheduled before its operands' can be ignored."
+)
+
 
 @triton.jit
 def tanh(x):
@@ -37,10 +47,10 @@ def tanh(x):
 def _fwd_kernel_stage1(
     Q,
     K_Buffer,
+    V_Buffer,
     sm_scale,
     Req_to_tokens,
     B_req_idx,
-    B_Start_Loc,
     B_Seqlen,
     Att_Out,
     stride_req_to_tokens_b,
@@ -48,152 +58,136 @@ def _fwd_kernel_stage1(
     stride_qh,
     stride_buf_kbs,
     stride_buf_kh,
-    att_stride_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    SPLIT_K: tl.constexpr,
+    NUM_KV_SPLITS: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
-    split_k_id = tl.program_id(2)
+    split_kv_id = tl.program_id(2)
 
-    reduce_dtype = Att_Out.dtype.element_ty
     cur_kv_head = cur_head // kv_group_num
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    mask_d = offs_d < Lk
+    mask_dv = offs_dv < Lv
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
-    q = tl.load(Q + off_q).to(reduce_dtype)
-
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
-    split_k_start = kv_len_per_split * split_k_id
-    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
-
-    for start_n in range(split_k_start, split_k_end, BLOCK_N):
-        offs_n = start_n + tl.arange(0, BLOCK_N)
-        k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
-            mask=offs_n < split_k_end,
-            other=0,
-        )
-        offs_buf_k = (
-            k_loc[:, None] * stride_buf_kbs
-            + cur_kv_head * stride_buf_kh
-            + offs_d[None, :]
-        )
-        k = tl.load(
-            K_Buffer + offs_buf_k,
-            mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk),
-            other=0.0,
-        ).to(reduce_dtype)
-        att_value = tl.sum(q[None, :] * k, 1)
-        att_value *= sm_scale
+    q = tl.load(Q + off_q, mask=mask_d, other=0.0)
 
-        if logit_cap > 0:
-            att_value = logit_cap * tanh(att_value / logit_cap)
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    split_kv_start = kv_len_per_split * split_kv_id
+    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
 
-        off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
-        tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end)
+    e_max = -float("inf")
+    e_sum = 0.0
+    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
+
+    if split_kv_end > split_kv_start:
+        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
+            offs_n = start_n + tl.arange(0, BLOCK_N)
+            kv_loc = tl.load(
+                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+                mask=offs_n < split_kv_end,
+                other=0,
+            )
+            offs_buf_k = (
+                kv_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_d[None, :]
+            )
+            k = tl.load(
+                K_Buffer + offs_buf_k,
+                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
+                other=0.0,
+            )
+            qk = tl.sum(q[None, :] * k, 1)
+            qk *= sm_scale
 
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
 
-@triton.jit
-def _fwd_kernel_stage2(
-    logits,
-    V_Buffer,
-    Out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Start_Loc,
-    B_Seqlen,
-    stride_logic_h,
-    stride_buf_vbs,
-    stride_buf_vh,
-    stride_obs,
-    stride_oh,
-    stride_req_to_token_b,
-    kv_group_num: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    Lv: tl.constexpr,
-):
-    cur_batch = tl.program_id(0)
-    cur_head = tl.program_id(1)
+            qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
 
-    cur_kv_head = cur_head // kv_group_num
-
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+            offs_buf_v = (
+                kv_loc[:, None] * stride_buf_vbs
+                + cur_kv_head * stride_buf_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Buffer + offs_buf_v,
+                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
+                other=0.0,
+            )
 
-    offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
+            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max)
+            acc *= re_scale
+            acc += tl.sum(p[:, None] * v, 0)
 
-    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
-    v_ptrs = V_Buffer + offs_buf_v
+            e_sum = e_sum * re_scale + tl.sum(p, 0)
+            e_max = n_e_max
 
-    e_max = float("-inf")
-    e_sum = 0.0
-    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
-
-    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        v_index = tl.load(
-            Req_to_tokens
-            + cur_batch_req_idx * stride_req_to_token_b
-            + (start_n + offs_n),
-            mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=0,
+        offs_mid_o = (
+            cur_batch * stride_mid_ob
+            + cur_head * stride_mid_oh
+            + split_kv_id * stride_mid_os
+            + offs_dv
         )
 
-        qk = tl.load(
-            logits
-            + cur_head * stride_logic_h
-            + (cur_batch_start_loc + start_n + offs_n),
-            mask=start_n + offs_n < cur_batch_seq_len,
-            other=float("-inf"),
+        tl.store(
+            Att_Out + offs_mid_o,
+            acc / e_sum,
+            mask=(mask_dv),
        )
 
-        n_e_max = tl.maximum(tl.max(qk, 0), e_max)
-        old_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max)
-        e_sum = e_sum * old_scale + tl.sum(p, 0)
-        v = tl.load(
-            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        offs_mid_o_1 = (
+            cur_batch * stride_mid_ob
+            + cur_head * stride_mid_oh
+            + split_kv_id * stride_mid_os
+            + Lv
         )
-        acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
-        e_max = n_e_max
 
-    acc = acc / e_sum
-    off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
-    out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=(offs_d < Lv))
+        tl.store(
+            Att_Out + offs_mid_o_1,
+            e_max + tl.log(e_sum),
+        )
 
 
 def _decode_att_m_fwd(
     q,
     k_buffer,
+    v_buffer,
     att_out,
     Req_to_tokens,
     B_req_idx,
-    B_Start_Loc,
     B_Seqlen,
-    max_len_in_batch,
+    num_kv_splits,
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 32
-    SPLIT_K = 8
+    BLOCK = 64
+    NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
+    Lv = v_buffer.shape[-1]
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
-    grid = (batch, head_num, SPLIT_K)
+    grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -202,14 +196,15 @@ def _decode_att_m_fwd(
         num_warps = 2
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
+    BLOCK_DV = triton.next_power_of_2(Lv)
 
     _fwd_kernel_stage1[grid](
         q,
         k_buffer,
+        v_buffer,
         sm_scale,
         Req_to_tokens,
         B_req_idx,
-        B_Start_Loc,
        B_Seqlen,
         att_out,
         Req_to_tokens.stride(0),
@@ -217,56 +212,20 @@
         q.stride(1),
         k_buffer.stride(0),
         k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
         att_out.stride(0),
+        att_out.stride(1),
+        att_out.stride(2),
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
-        SPLIT_K=SPLIT_K,
+        NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=num_warps,
-        num_stages=1,
+        num_stages=2,
         Lk=Lk,
-    )
-
-
-def _decode_softmax_reducev_fwd(
-    logits,
-    v_buffer,
-    o,
-    req_to_tokens,
-    b_req_idx,
-    b_start_loc,
-    b_seq_len,
-):
-    BLOCK = 64
-    batch, head = b_seq_len.shape[0], logits.shape[0]
-    grid = (batch, head, 1)
-    kv_group_num = logits.shape[0] // v_buffer.shape[1]
-
-    num_warps = 1
-
-    Lv = v_buffer.shape[-1]
-    BLOCK_DMODEL = triton.next_power_of_2(Lv)
-
-    _fwd_kernel_stage2[grid](
-        logits,
-        v_buffer,
-        o,
-        req_to_tokens,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        logits.stride(0),
-        v_buffer.stride(0),
-        v_buffer.stride(1),
-        o.stride(0),
-        o.stride(1),
-        req_to_tokens.stride(0),
-        kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_N=BLOCK,
-        num_warps=num_warps,
-        num_stages=3,
         Lv=Lv,
     )
 
@@ -275,10 +234,10 @@ def _decode_softmax_reducev_fwd(
 def _fwd_grouped_kernel_stage1(
     Q,
     K_Buffer,
+    V_Buffer,
     sm_scale,
     Req_to_tokens,
     B_req_idx,
-    B_Start_Loc,
     B_Seqlen,
     Att_Out,
     stride_req_to_tokens_b,
@@ -286,23 +245,27 @@ def _fwd_grouped_kernel_stage1(
     stride_qh,
     stride_buf_kbs,
     stride_buf_kh,
-    att_stride_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
     kv_group_num: tl.constexpr,
     q_head_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_H: tl.constexpr,
-    SPLIT_K: tl.constexpr,
+    NUM_KV_SPLITS: tl.constexpr,
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
+    Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head_id = tl.program_id(1)
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-    split_k_id = tl.program_id(2)
-
-    reduce_dtype = Att_Out.dtype.element_ty
+    split_kv_id = tl.program_id(2)
 
     if BLOCK_H < kv_group_num:
         VALID_BLOCK_H: tl.constexpr = BLOCK_H
@@ -313,171 +276,135 @@ def _fwd_grouped_kernel_stage1(
     mask_h = mask_h & (cur_head < q_head_num)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    mask_d = offs_d < Lk
+    mask_dv = offs_dv < Lv
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
-    q = tl.load(
-        Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0
-    ).to(reduce_dtype)
+    q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        mask_dpe = offs_dpe < Lk
         off_qpe = (
             cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
         )
-        qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype)
-
-    kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K)
-    split_k_start = kv_len_per_split * split_k_id
-    split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len)
-
-    for start_n in range(split_k_start, split_k_end, BLOCK_N):
-        offs_n = start_n + tl.arange(0, BLOCK_N)
-        k_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
-            mask=offs_n < split_k_end,
-            other=0,
-        )
-        offs_buf_k = (
-            k_loc[None, :] * stride_buf_kbs
-            + cur_kv_head * stride_buf_kh
-            + offs_d[:, None]
+        qpe = tl.load(
+            Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
        )
-        k = tl.load(
-            K_Buffer + offs_buf_k,
-            mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk),
-            other=0.0,
-        ).to(reduce_dtype)
-        qk = tl.dot(q, k)
-        if BLOCK_DPE > 0:
-            offs_buf_kpe = (
-                k_loc[None, :] * stride_buf_kbs
+
+    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+    split_kv_start = kv_len_per_split * split_kv_id
+    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
+
+    if split_kv_end > split_kv_start:
+        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
+            offs_n = start_n + tl.arange(0, BLOCK_N)
+            kv_loc = tl.load(
+                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+                mask=offs_n < split_kv_end,
+                other=0,
+            )
+            offs_buf_k = (
+                kv_loc[None, :] * stride_buf_kbs
                 + cur_kv_head * stride_buf_kh
-                + offs_dpe[:, None]
+                + offs_d[:, None]
             )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n[None, :] < split_k_end,
+            k = tl.load(
+                K_Buffer + offs_buf_k,
+                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                 other=0.0,
-            ).to(reduce_dtype)
-            qk += tl.dot(qpe, kpe)
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
-        offs_o = cur_head[:, None] * att_stride_h + (
-            cur_batch_in_all_start_index + offs_n[None, :]
-        )
-
-        tl.store(
-            Att_Out + offs_o,
-            qk,
-            mask=mask_h[:, None] & (offs_n[None, :] < split_k_end),
-        )
-
-
-@triton.jit
-def _fwd_grouped_kernel_stage2(
-    logits,
-    V_Buffer,
-    Out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Start_Loc,
-    B_Seqlen,
-    stride_logic_h,
-    stride_buf_vbs,
-    stride_buf_vh,
-    stride_obs,
-    stride_oh,
-    stride_req_to_token_b,
-    kv_group_num: tl.constexpr,
-    q_head_num: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr,
-    BLOCK_H: tl.constexpr,
-    Lv: tl.constexpr,
-):
-    cur_batch = tl.program_id(0)
-    cur_head_id = tl.program_id(1)
-    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
-
-    if BLOCK_H < kv_group_num:
-        VALID_BLOCK_H: tl.constexpr = BLOCK_H
-    else:
-        VALID_BLOCK_H: tl.constexpr = kv_group_num
-    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
-    mask_h = mask_h & (cur_head < q_head_num)
+            )
+            qk = tl.dot(q, k.to(q.dtype))
+            if BLOCK_DPE > 0:
+                offs_buf_kpe = (
+                    kv_loc[None, :] * stride_buf_kbs
+                    + cur_kv_head * stride_buf_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Buffer + offs_buf_kpe,
+                    mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
+                    other=0.0,
+                )
+                qk += tl.dot(qpe, kpe.to(qpe.dtype))
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(
+                mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
+            )
 
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+            offs_buf_v = (
+                kv_loc[:, None] * stride_buf_vbs
+                + cur_kv_head * stride_buf_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Buffer + offs_buf_v,
+                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
+                other=0.0,
+            )
 
-    offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
+            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            acc *= re_scale[:, None]
+            acc += tl.dot(p.to(v.dtype), v)
 
-    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
-    v_ptrs = V_Buffer + offs_buf_v
+            e_sum = e_sum * re_scale + tl.sum(p, 1)
+            e_max = n_e_max
 
-    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
-    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
-
-    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
-        start_n = tl.multiple_of(start_n, BLOCK_N)
-        v_index = tl.load(
-            Req_to_tokens
-            + cur_batch_req_idx * stride_req_to_token_b
-            + (start_n + offs_n),
-            mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=0,
+        offs_mid_o = (
+            cur_batch * stride_mid_ob
+            + cur_head[:, None] * stride_mid_oh
+            + split_kv_id * stride_mid_os
+            + offs_dv[None, :]
         )
 
-        offs_qk = cur_head[:, None] * stride_logic_h + (
-            cur_batch_start_loc + start_n + offs_n[None, :]
+        tl.store(
+            Att_Out + offs_mid_o,
+            acc / e_sum[:, None],
+            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )
 
-        qk = tl.load(
-            logits + offs_qk,
-            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
-            other=float("-inf"),
+        offs_mid_o_1 = (
+            cur_batch * stride_mid_ob
+            + cur_head * stride_mid_oh
+            + split_kv_id * stride_mid_os
+            + Lv
        )
 
-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        old_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        e_sum = e_sum * old_scale + tl.sum(p, 1)
-        v = tl.load(
-            v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
+        tl.store(
+            Att_Out + offs_mid_o_1,
+            e_max + tl.log(e_sum),
+            mask=mask_h,
        )
-        p = p.to(v.dtype)
-        acc = acc * old_scale[:, None] + tl.dot(p, v)
-        e_max = n_e_max
-
-    acc = acc / e_sum[:, None]
-    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
-    out_ptrs = Out + off_o
-    tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
 
 
 def _decode_grouped_att_m_fwd(
     q,
     k_buffer,
+    v_buffer,
     att_out,
     Req_to_tokens,
     B_req_idx,
-    B_Start_Loc,
     B_Seqlen,
-    max_len_in_batch,
+    num_kv_splits,
     sm_scale,
     logit_cap,
 ):
-    BLOCK = 64
+    BLOCK = 32
     Lk = k_buffer.shape[-1]
+    Lv = v_buffer.shape[-1]
 
     if Lk == 576:
         BLOCK_DMODEL = 512
@@ -488,20 +415,19 @@ def _decode_grouped_att_m_fwd(
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lk)
         BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
-    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
-    SPLIT_K = 8
+    BLOCK_H = 16
+    NUM_KV_SPLITS = num_kv_splits
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
-        SPLIT_K,
+        NUM_KV_SPLITS,
     )
 
-    num_warps = 4
-
     extra_kargs = {}
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
@@ -511,10 +437,10 @@ def _decode_grouped_att_m_fwd(
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
+        v_buffer,
         sm_scale,
         Req_to_tokens,
         B_req_idx,
-        B_Start_Loc,
        B_Seqlen,
         att_out,
         Req_to_tokens.stride(0),
@@ -522,41 +448,97 @@ def _decode_grouped_att_m_fwd(
         q.stride(1),
         k_buffer.stride(0),
         k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
         att_out.stride(0),
+        att_out.stride(1),
+        att_out.stride(2),
         kv_group_num=kv_group_num,
         q_head_num=head_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_N=BLOCK,
         BLOCK_H=BLOCK_H,
-        SPLIT_K=SPLIT_K,
+        NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
-        num_warps=num_warps,
-        num_stages=1,
+        num_warps=4,
+        num_stages=2,
         Lk=Lk,
+        Lv=Lv,
         **extra_kargs,
     )
 
 
-def _decode_grouped_softmax_reducev_fwd(
+@triton.jit
+def _fwd_kernel_stage2(
+    Mid_O,
+    O,
+    B_Seqlen,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
+    stride_obs,
+    stride_oh,
+    NUM_KV_SPLITS: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    Lv: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_head = tl.program_id(1)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+
+    offs_d = tl.arange(0, BLOCK_DV)
+    mask_d = offs_d < Lv
+
+    e_sum = 0.0
+    e_max = -float("inf")
+    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
+
+    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
+    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
+
+    for split_kv_id in range(0, NUM_KV_SPLITS):
+        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
+        split_kv_start = kv_len_per_split * split_kv_id
+        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len)
+
+        if split_kv_end > split_kv_start:
+            tv = tl.load(
+                Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0
+            )
+            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
+            n_e_max = tl.maximum(tlogic, e_max)
+
+            old_scale = tl.exp(e_max - n_e_max)
+            acc *= old_scale
+            exp_logic = tl.exp(tlogic - n_e_max)
+            acc += exp_logic * tv
+
+            e_sum = e_sum * old_scale + exp_logic
+            e_max = n_e_max
+
+    tl.store(
+        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
+        acc / e_sum,
+        mask=mask_d,
+    )
+
+
+def _decode_softmax_reducev_fwd(
     logits,
-    v_buffer,
+    q,
     o,
-    req_to_tokens,
-    b_req_idx,
-    b_start_loc,
+    v_buffer,
     b_seq_len,
+    num_kv_splits,
 ):
-    BLOCK = 128
-    batch, head_num = b_seq_len.shape[0], logits.shape[0]
-    kv_group_num = logits.shape[0] // v_buffer.shape[1]
-    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
-    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
-
-    num_warps = 8
-
+    batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
-    BLOCK_DMODEL = triton.next_power_of_2(Lv)
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    NUM_KV_SPLITS = num_kv_splits
 
     extra_kargs = {}
     if is_hip_:
@@ -564,28 +546,21 @@ def _decode_grouped_softmax_reducev_fwd(
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
-    _fwd_grouped_kernel_stage2[grid](
+    grid = (batch, head_num)
+    _fwd_kernel_stage2[grid](
         logits,
-        v_buffer,
         o,
-        req_to_tokens,
-        b_req_idx,
-        b_start_loc,
         b_seq_len,
         logits.stride(0),
-        v_buffer.stride(0),
-        v_buffer.stride(1),
+        logits.stride(1),
+        logits.stride(2),
         o.stride(0),
         o.stride(1),
-        req_to_tokens.stride(0),
-        kv_group_num=kv_group_num,
-        q_head_num=head_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_N=BLOCK,
-        BLOCK_H=BLOCK_H,
+        NUM_KV_SPLITS=NUM_KV_SPLITS,
+        BLOCK_DV=BLOCK_DV,
         Lv=Lv,
-        num_warps=num_warps,
-        num_stages=1,
+        num_warps=4,
+        num_stages=2,
         **extra_kargs,
     )
 
@@ -597,34 +572,25 @@ def decode_attention_fwd_normal(
     o,
     req_to_token,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
+    num_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
     _decode_att_m_fwd(
         q,
         k_buffer,
+        v_buffer,
         attn_logits,
         req_to_token,
        b_req_idx,
-        b_start_loc,
        b_seq_len,
-        max_len_in_batch,
+        num_kv_splits,
        sm_scale,
        logit_cap,
    )
-    _decode_softmax_reducev_fwd(
-        attn_logits,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -634,34 +600,25 @@ def decode_attention_fwd_grouped(
     o,
     req_to_token,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
+    num_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
     _decode_grouped_att_m_fwd(
         q,
         k_buffer,
+        v_buffer,
         attn_logits,
         req_to_token,
        b_req_idx,
-        b_start_loc,
        b_seq_len,
-        max_len_in_batch,
+        num_kv_splits,
        sm_scale,
        logit_cap,
    )
-    _decode_grouped_softmax_reducev_fwd(
-        attn_logits,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
 
 
 def decode_attention_fwd(
@@ -671,13 +628,13 @@ def decode_attention_fwd(
     o,
     req_to_token,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
     attn_logits,
-    max_len_in_batch,
+    num_kv_splits,
     sm_scale,
     logit_cap=0.0,
 ):
+    assert num_kv_splits == attn_logits.shape[2]
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -689,10 +646,9 @@ def decode_attention_fwd(
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
-            max_len_in_batch,
+            num_kv_splits,
             sm_scale,
             logit_cap,
         )
@@ -705,10 +661,9 @@ def decode_attention_fwd(
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
-            max_len_in_batch,
+            num_kv_splits,
             sm_scale,
             logit_cap,
        )
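
For orientation on the decode_attention.py rewrite above: 0.4.0.post1 ran a two-kernel pipeline in which stage 1 wrote a raw attention logit for every KV token (hence b_start_loc, max_len_in_batch, and a flat per-token buffer) and stage 2 performed the softmax and the weighted sum over V. 0.4.1 switches to a flash-decoding style split-KV scheme: stage 1 now emits, per (batch, head, KV split), an already-normalized partial output at attn_logits[..., :Lv] plus that split's log-sum-exp at attn_logits[..., Lv], and the new _fwd_kernel_stage2 only merges the splits; callers pass num_kv_splits instead (checked by the new assert against attn_logits.shape[2]). Below is a minimal, slow PyTorch reference of the stage-2 merge, not sglang code: the buffer layout is read off the kernel's stores, while the function name and the float32 dtype are assumptions for illustration.

    import math

    import torch


    def merge_kv_splits(attn_logits, b_seq_len):
        # attn_logits: (batch, heads, num_kv_splits, Lv + 1); per split,
        #   [..., :Lv] is the softmax-weighted partial output (acc / e_sum),
        #   [..., Lv]  is the split's log-sum-exp (e_max + log(e_sum)).
        batch, heads, num_splits, lv1 = attn_logits.shape
        lv = lv1 - 1
        out = torch.empty(batch, heads, lv, dtype=attn_logits.dtype)
        for b in range(batch):
            seq_len = int(b_seq_len[b])
            kv_per_split = (seq_len + num_splits - 1) // num_splits  # tl.cdiv
            for h in range(heads):
                e_max, e_sum = float("-inf"), 0.0
                acc = torch.zeros(lv, dtype=torch.float32)
                for s in range(num_splits):
                    start = s * kv_per_split
                    if min(start + kv_per_split, seq_len) <= start:
                        continue  # empty split: stage 1 stored nothing here
                    tv = attn_logits[b, h, s, :lv].float()
                    tlogic = float(attn_logits[b, h, s, lv])
                    n_e_max = max(tlogic, e_max)
                    old_scale = math.exp(e_max - n_e_max)  # 0.0 on the first split
                    exp_logic = math.exp(tlogic - n_e_max)
                    acc = acc * old_scale + exp_logic * tv
                    e_sum = e_sum * old_scale + exp_logic
                    e_max = n_e_max
                out[b, h] = acc / e_sum
        return out

Because each split's stored output is already divided by its own e_sum, re-weighting it by exp(lse_split - running_max) restores the unnormalized numerator, so the merged result equals a single softmax over the whole sequence; this identity is also why empty splits can simply be skipped.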