sglang-0.2.12-py3-none-any.whl → sglang-0.2.14-py3-none-any.whl
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/layers/decode_attention.py

@@ -26,7 +26,7 @@ import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 
-if global_server_args_dict.get("
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
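Note: the renamed `triton_attention_reduce_in_fp32` key selects fp32 accumulation types (`REDUCE_TRITON_TYPE` / `REDUCE_TORCH_TYPE`) for the Triton decode kernels. A toy, self-contained illustration of why an fp32-reduce option matters (plain NumPy, not sglang code):

```python
import numpy as np

# Sequentially accumulating many small fp16 terms stalls once the running sum
# grows, because fp16 spacing widens with magnitude; an fp32 accumulator does not.
x = np.random.rand(16384).astype(np.float16)
acc = np.float16(0.0)
for v in x:  # naive fp16 reduction, like a half-precision accumulator
    acc = np.float16(acc + v)
print(float(acc))                         # visibly below the true sum
print(float(x.astype(np.float32).sum()))  # close to the true sum (about 8192)
```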
@@ -58,7 +58,6 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -78,10 +77,6 @@ def _fwd_kernel_stage1(
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
-    if BLOCK_DPE > 0:
-        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
-
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
     block_stard_index = start_n * BLOCK_N
@@ -106,19 +101,6 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
-        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
-            offs_buf_kpe = (
-                k_loc[:, None] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[None, :]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[:, None] < cur_batch_end_index,
-                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
-            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale
 
         if logit_cap > 0:
@@ -214,14 +196,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
    assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
-
-    if Lk == 576:
-        BLOCK_DMODEL = 512
-        BLOCK_DPE = 64
-    else:
-        BLOCK_DMODEL = Lk
-        BLOCK_DPE = 0
+    assert Lk in {16, 32, 64, 128, 256}
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
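The deleted 576 branch served DeepSeek-V2-style MLA decode, where each key head is 576 wide: a 512-dim latent part plus a 64-dim rotary part. That split now lives only in the grouped-head path added further down. A minimal sketch of the partition, using a hypothetical helper that mirrors the constants in this diff:

```python
def head_dim_blocks(Lk: int) -> tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE): the non-positional and rotary slices."""
    if Lk == 576:      # MLA: 512 latent dims + 64 rotary dims
        return 512, 64
    return Lk, 0       # standard heads carry no separate rotary slice

assert head_dim_blocks(576) == (512, 64)
assert head_dim_blocks(128) == (128, 0)
```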
@@ -249,8 +224,7 @@ def _decode_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
@@ -296,6 +270,293 @@ def _decode_softmax_reducev_fwd(
     )
 
 
+@triton.jit
+def _fwd_grouped_kernel_stage1(
+    Q,
+    K_Buffer,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    Att_Out,
+    stride_req_to_tokens_b,
+    stride_qbs,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    att_stride_h,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    logit_cap: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+    start_n = tl.program_id(2)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_start_index = 0
+    cur_batch_end_index = cur_batch_seq_len
+
+    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = (
+            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        )
+
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    block_stard_index = start_n * BLOCK_N
+    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+
+    for start_mark in range(0, block_mask, 1):
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
+            REDUCE_TRITON_TYPE
+        )
+        offs_n_new = cur_batch_start_index + offs_n
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        offs_buf_k = (
+            k_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k,
+            mask=offs_n_new[None, :] < cur_batch_end_index,
+            other=0.0,
+        ).to(REDUCE_TRITON_TYPE)
+        qk = tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
+                REDUCE_TRITON_TYPE
+            )
+            offs_buf_kpe = (
+                k_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[None, :] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            qk += tl.dot(qpe, kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        offs_o = cur_head[:, None] * att_stride_h + (
+            cur_batch_in_all_start_index + offs_n[None, :]
+        )
+
+        tl.store(
+            Att_Out + offs_o,
+            qk,
+            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+        )
+
+
+@triton.jit
+def _fwd_grouped_kernel_stage2(
+    Logics,
+    V_Buffer,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    stride_logic_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_obs,
+    stride_oh,
+    stride_req_to_token_b,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
+    v_ptrs = V_Buffer + offs_buf_v
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (start_n + offs_n),
+            mask=(start_n + offs_n) < cur_batch_seq_len,
+            other=0,
+        )
+
+        offs_qk = cur_head[:, None] * stride_logic_h + (
+            cur_batch_start_loc + start_n + offs_n[None, :]
+        )
+
+        qk = tl.load(
+            Logics + offs_qk,
+            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
+            other=float("-inf"),
+        )
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        e_sum = e_sum * old_scale + tl.sum(p, 1)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        p = p.to(v.dtype)
+        acc = acc * old_scale[:, None] + tl.dot(p, v)
+        e_max = n_e_max
+
+    acc = acc / e_sum[:, None]
+    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+
+
+def _decode_grouped_att_m_fwd(
+    q,
+    k_buffer,
+    att_out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap,
+):
+    BLOCK = 32
+    # shape constraints
+    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
+    assert Lq == Lk
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
+
+    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        triton.cdiv(max_len_in_batch, BLOCK),
+    )
+
+    num_warps = 4
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        att_out.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        logit_cap=logit_cap,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
+def _decode_grouped_softmax_reducev_fwd(
+    logics,
+    v_buffer,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+):
+    BLOCK = 128
+    batch, head_num = b_seq_len.shape[0], logics.shape[0]
+    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
+
+    num_warps = 8
+
+    _fwd_grouped_kernel_stage2[grid](
+        logics,
+        v_buffer,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        logics.stride(0),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        o.stride(0),
+        o.stride(1),
+        req_to_tokens.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
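The grouped kernels pack all query heads that share a KV head into one Triton program: `BLOCK_H = max(16, next_power_of_2(kv_group_num))`, so the per-head multiply-adds become a single `tl.dot`. For example, with 32 query heads over 4 KV heads, `kv_group_num = 8`, `BLOCK_H = 16`, and the head grid dimension is `cdiv(32, min(16, 8)) = 4`. Stage 1 stores scaled (optionally tanh-capped) logits; stage 2 reduces them against V with an online softmax, rescaling the running accumulator whenever a new block raises the row maximum. A NumPy sketch of that update rule (toy shapes, not sglang code):

```python
import numpy as np

def streaming_softmax_v(qk_blocks, v_blocks):
    """qk_blocks: (H, N) score tiles; v_blocks: matching (N, D) value tiles."""
    H, D = qk_blocks[0].shape[0], v_blocks[0].shape[1]
    e_max = np.full(H, -np.inf)  # running row maximum
    e_sum = np.zeros(H)          # running softmax denominator
    acc = np.zeros((H, D))       # running numerator @ V
    for qk, v in zip(qk_blocks, v_blocks):
        n_e_max = np.maximum(qk.max(axis=1), e_max)
        old_scale = np.exp(e_max - n_e_max)  # rescale history to the new max
        p = np.exp(qk - n_e_max[:, None])
        e_sum = e_sum * old_scale + p.sum(axis=1)
        acc = acc * old_scale[:, None] + p @ v
        e_max = n_e_max
    return acc / e_sum[:, None]

rng = np.random.default_rng(0)
qk = rng.normal(size=(4, 256))
v = rng.normal(size=(256, 64))
p = np.exp(qk - qk.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ v  # one-shot softmax(QK) @ V
out = streaming_softmax_v([qk[:, :128], qk[:, 128:]], [v[:128], v[128:]])
assert np.allclose(out, ref)
```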
@@ -316,24 +577,51 @@ def decode_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
 
-    _decode_att_m_fwd(
-        q,
-        k_buffer,
-        att_m,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        max_len_in_batch,
-        sm_scale,
-        logit_cap,
-    )
-    _decode_softmax_reducev_fwd(
-        att_m,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    kv_group_num = q.shape[1] // v_buffer.shape[1]
+
+    if kv_group_num == 1:
+        # MHA
+        _decode_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
+    else:
+        # GQA/MQA/MLA
+        _decode_grouped_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_grouped_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
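`decode_attention_fwd` now routes on the ratio of query heads to KV heads; only the grouped kernels implement `BLOCK_H` head packing and the rotary (DPE) slice. A small sketch of the ratio:

```python
def kv_group_num(num_q_heads: int, num_kv_heads: int) -> int:
    """How many query heads share one KV head."""
    return num_q_heads // num_kv_heads

assert kv_group_num(32, 32) == 1   # MHA     -> original per-head kernels
assert kv_group_num(32, 8) == 4    # GQA     -> grouped kernels
assert kv_group_num(32, 1) == 32   # MQA/MLA -> grouped kernels
```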
sglang/srt/layers/extend_attention.py

@@ -275,7 +275,9 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = Lv
 
-    if CUDA_CAPABILITY[0] >= 8:
+    if CUDA_CAPABILITY[0] >= 9:
+        BLOCK_M, BLOCK_N = (128, 64)
+    elif CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
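`extend_attention_fwd` picks Triton tile sizes by GPU generation, and sm90 (Hopper) now gets a dedicated (128, 64) tiling instead of falling into the sm80 branch. A sketch of the selection logic; the `torch.cuda.get_device_capability()` initialization is an assumption about how `CUDA_CAPABILITY` is populated:

```python
import torch

CUDA_CAPABILITY = torch.cuda.get_device_capability()  # e.g. (9, 0) on H100

def pick_blocks(Lq: int) -> tuple[int, int]:
    """(BLOCK_M, BLOCK_N) tile sizes, mirroring the branch in this hunk."""
    if CUDA_CAPABILITY[0] >= 9:
        return 128, 64
    elif CUDA_CAPABILITY[0] >= 8:
        return (128, 128) if Lq <= 128 else (64, 64)
    return (64, 64) if Lq <= 128 else (32, 32)
```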
sglang/srt/layers/fused_moe/__init__.py

@@ -0,0 +1 @@
+from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase
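The new `__init__.py` re-exports the MoE layer API at the subpackage root, so both import paths below resolve to the same classes:

```python
from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.fused_moe.layer import FusedMoE as DirectFusedMoE

assert FusedMoE is DirectFusedMoE  # the re-export aliases the same class
```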