sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
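Two of the renames above move the Triton kernels into a dedicated `sglang/srt/layers/triton_attention/` package, so downstream code that imports them directly needs updated paths. A minimal sketch of the adjustment (the prefill import is taken verbatim from the diff below; the extend import is the analogous path implied by the rename):

```python
# Old import paths (sglang 0.2.15):
# from sglang.srt.layers.prefill_attention import context_attention_fwd
# from sglang.srt.layers.extend_attention import extend_attention_fwd

# New import paths (sglang 0.3.1), following the triton_attention/ move:
from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd
```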
sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py}

```diff
@@ -15,14 +15,14 @@ limitations under the License.
 
 """
 Memory-efficient attention for prefill.
-It
+It supports page size = 1 and prefill with KV cache (i.e. extend).
 """
 
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.prefill_attention import context_attention_fwd
+from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
@@ -61,12 +61,14 @@ def _fwd_kernel(
     stride_buf_vbs,
     stride_buf_vh,
     stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DPE: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
-    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -86,13 +88,18 @@ def _fwd_kernel(
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
 
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
     )
-    q = tl.load(
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -104,7 +111,7 @@ def _fwd_kernel(
         )
         qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
 
-    #
+    # stage 1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)
 
     acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
@@ -125,7 +132,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_buf_kh
             + offs_d[:, None]
         )
-        k = tl.load(
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
 
         qk = tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
@@ -157,13 +166,15 @@ def _fwd_kernel(
             + cur_kv_head * stride_buf_vh
             + offs_dv[None, :]
        )
-        v = tl.load(
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
         p = p.to(v.dtype)
         acc = acc * re_scale[:, None] + tl.dot(p, v)
 
         e_max = n_e_max
 
-    #
+    # stage 2: compute the trianlge part
 
     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -176,7 +187,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
-        k = tl.load(
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
 
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
@@ -214,7 +227,9 @@ def _fwd_kernel(
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
        )
-        v = tl.load(
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
         p = p.to(v.dtype)
         acc = acc * re_scale[:, None] + tl.dot(p, v)
 
@@ -226,7 +241,9 @@ def _fwd_kernel(
         + cur_head * stride_oh
         + offs_dv[None, :]
     )
-    tl.store(
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
 
 
 def extend_attention_fwd(
```
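The recurring change in these kernel hunks is head-dimension masking: the true dims `Lq`/`Lv` are now passed in as `tl.constexpr` arguments, and every `tl.load`/`tl.store` is guarded by `mask_d = offs_d < Lq` or `mask_dv = offs_dv < Lv`, while the wrapper (below) rounds the block sizes up with `triton.next_power_of_2`. This is a standalone sketch of the same pad-and-mask pattern, not the sglang kernel itself; the copy kernel and the 96-dim example are made up for illustration:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _masked_copy(X, Y, D: tl.constexpr, BLOCK_D: tl.constexpr):
    # Same idea as `mask_d = offs_d < Lq` above: the block is padded up to a
    # power of two, and lanes past the real size D are masked off.
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(X + offs, mask=mask, other=0.0)
    tl.store(Y + offs, x, mask=mask)


x = torch.randn(96, device="cuda")  # a head dim that is not a power of two
y = torch.zeros_like(x)
_masked_copy[(1,)](x, y, D=96, BLOCK_D=triton.next_power_of_2(96))  # BLOCK_D == 128
assert torch.equal(x, y)
```

Masked-off lanes are loaded as `0.0`, so they contribute nothing to the dot products, which is what allows the old `assert Lq in {…}` whitelists to go away. The remaining hunks in the same file update the `extend_attention_fwd` wrapper and `redundant_attention`, and delete the inline `test()` harness: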
```diff
@@ -238,39 +255,34 @@ def extend_attention_fwd(
     v_buffer,
     req_to_tokens,
     b_req_idx,
-    b_start_loc,
     b_seq_len,
-    b_seq_len_prefix,
-    b_start_loc_extend,
     b_seq_len_extend,
-
+    b_start_loc_extend,
     max_len_extend,
     sm_scale=None,
-    logit_cap
+    logit_cap=0.0,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
 
     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    Lq, Lk, Lv
+    Lq, Lk, Lv = (
         q_extend.shape[-1],
         k_extend.shape[-1],
         v_extend.shape[-1],
-        o_extend.shape[-1],
     )
 
-    assert Lq == Lk and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256, 576}
-    assert Lv in {16, 32, 64, 128, 256, 512}
-
     if Lq == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
     else:
-        BLOCK_DMODEL = Lq
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
-    BLOCK_DV = Lv
+    BLOCK_DV = triton.next_power_of_2(Lv)
 
     if CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
@@ -287,7 +299,7 @@ def extend_attention_fwd(
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
-    sm_scale = 1.0 / (Lq**0.5)
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
     batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
@@ -322,25 +334,24 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
+        logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
         num_warps=num_warps,
         num_stages=num_stages,
-        logit_cap=logit_cap,
     )
 
 
 def redundant_attention(
     q_extend,
-    k_extend,
-    v_extend,
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
     b_req_idx,
     b_start_loc,
     b_seq_len,
@@ -371,106 +382,3 @@ def redundant_attention(
         pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
         o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
         pt += cur_seq_len_extend
-
-
-def test():
-    torch.manual_seed(0)
-
-    B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
-    dtype = torch.float16
-
-    b_seq_len_prefix = torch.randint(
-        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
-    )
-    b_seq_len_extend = torch.randint(
-        1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
-    )
-    b_seq_len = b_seq_len_prefix + b_seq_len_extend
-    max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
-
-    b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
-    req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda")
-    b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
-    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
-    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
-    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
-    for i in range(B):
-        req_to_tokens[i, : b_seq_len[i]] = torch.arange(
-            b_start_loc[i], b_start_loc[i] + b_seq_len[i]
-        )
-
-    total_token_num = torch.sum(b_seq_len).item()
-    extend_token_num = torch.sum(b_seq_len_extend).item()
-    k_buffer = torch.empty(
-        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
-    ).normal_(mean=0.1, std=0.2)
-    v_buffer = torch.empty(
-        (total_token_num, H_KV, D), dtype=dtype, device="cuda"
-    ).normal_(mean=0.1, std=0.2)
-
-    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
-    q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
-    for i in range(B):
-        extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
-        extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
-        extend_start = b_start_loc_extend[i]
-        extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
-        k_extend[extend_start:extend_end] = k_buffer[
-            extend_start_in_buffer:extend_end_in_buffer
-        ]
-        v_extend[extend_start:extend_end] = v_buffer[
-            extend_start_in_buffer:extend_end_in_buffer
-        ]
-        q_extend[extend_start:extend_end] = torch.empty(
-            (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
-        ).normal_(mean=0.1, std=0.2)
-
-    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
-    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
-
-    b_seq_len_extend = b_seq_len - b_seq_len_prefix
-    b_start_loc_extend = torch.zeros_like(b_seq_len)
-    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
-    max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
-    extend_attention_fwd(
-        q_extend,
-        k_extend,
-        v_extend,
-        o_extend,
-        k_buffer,
-        v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        b_seq_len_prefix,
-        b_start_loc_extend,
-        b_seq_len_extend,
-        max_len_in_batch,
-        max_len_extend,
-    )
-
-    redundant_attention(
-        q_extend,
-        k_extend,
-        v_extend,
-        o_redundant,
-        k_buffer,
-        v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        b_seq_len_prefix,
-        max_len_in_batch,
-    )
-
-    print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant)))
-    print("Max: ", torch.max(torch.abs(o_extend - o_redundant)))
-
-    assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
-
-
-if __name__ == "__main__":
-    test()
```
sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py}

```diff
@@ -48,6 +48,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -72,7 +73,11 @@ def _fwd_kernel(
     off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
     off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
 
-
+    mask_d = offs_d < Lk
+
+    q = tl.load(
+        Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
+    )
 
     k_ptrs = K + off_k
     v_ptrs = V + off_v
@@ -89,7 +94,7 @@ def _fwd_kernel(
         # -- compute qk ----
         k = tl.load(
             k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
-            mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
+            mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
             other=0.0,
         )
         # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
@@ -118,7 +123,7 @@ def _fwd_kernel(
         # update acc
         v = tl.load(
             v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
-            mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
+            mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
             other=0.0,
         )
 
@@ -134,7 +139,9 @@ def _fwd_kernel(
         + offs_d[None, :]
     )
     out_ptrs = Out + off_o
-    tl.store(
+    tl.store(
+        out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
+    )
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
@@ -144,8 +151,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
     BLOCK = 64
 
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
-    assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128, 256}
 
     sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
@@ -172,8 +177,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
         o.stride(1),
         kv_group_num=kv_group_num,
         BLOCK_M=BLOCK,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=triton.next_power_of_2(Lk),
         BLOCK_N=BLOCK,
         num_warps=num_warps,
         num_stages=1,
+        Lk=Lk,
     )
```