sglang 0.4.2.post4__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/global_config.py +2 -0
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/attention/flashinfer_backend.py +235 -110
- sglang/srt/layers/attention/triton_backend.py +358 -72
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/linear.py +12 -5
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +51 -5
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
- sglang/srt/layers/quantization/fp8_kernel.py +123 -17
- sglang/srt/layers/quantization/fp8_utils.py +33 -4
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +16 -3
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/tokenizer_manager.py +6 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +12 -1
- sglang/srt/model_executor/model_runner.py +12 -2
- sglang/srt/models/deepseek_v2.py +17 -7
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/METADATA +4 -3
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/RECORD +57 -41
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post4.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/global_config.py
CHANGED
sglang/srt/entrypoints/engine.py
CHANGED
@@ -297,7 +297,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] =
+    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
@@ -317,7 +317,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
        assert_pkg_version(
            "flashinfer_python",
-           "0.2.
+           "0.2.1.post1",
            "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
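The first hunk derives NCCL_NVLS_ENABLE from the new enable_nccl_nvls server argument instead of a fixed value. A minimal sketch of the resulting mapping (the Args dataclass below is an illustrative stand-in for sglang's ServerArgs, not the real class):

import os
from dataclasses import dataclass

@dataclass
class Args:  # illustrative stand-in for sglang.srt.server_args.ServerArgs
    enable_nccl_nvls: bool = False

def set_nvls_env(args: Args) -> None:
    # str(int(...)) maps False -> "0" and True -> "1", the format NCCL expects
    os.environ["NCCL_NVLS_ENABLE"] = str(int(args.enable_nccl_nvls))

set_nvls_env(Args(enable_nccl_nvls=True))
assert os.environ["NCCL_NVLS_ENABLE"] == "1"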
sglang/srt/layers/attention/flashinfer_backend.py
CHANGED
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):

 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]


 @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
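These two hunks gate the new FlashInfer MLA path on both the model architecture and an explicit opt-in flag. A condensed sketch of the selection logic (the standalone helper is illustrative; in the diff the check lives inline in FlashInferAttnBackend.__init__):

def should_enable_flashinfer_mla(architectures, server_args_dict) -> bool:
    # Both conditions must hold: a DeepSeek-V3-style model and the
    # enable_flashinfer_mla server argument set by the user.
    return (
        "DeepseekV3ForCausalLM" in architectures
        and bool(server_args_dict.get("enable_flashinfer_mla", False))
    )

assert should_enable_flashinfer_mla(["DeepseekV3ForCausalLM"], {"enable_flashinfer_mla": True})
assert not should_enable_flashinfer_mla(["Qwen2ForCausalLM"], {"enable_flashinfer_mla": True})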
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
                 )
-            )

         # Create indices updater
         if not skip_prefill:
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
                     )
-                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
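Both wrapper variants are constructed over preallocated buffers that are only sliced per captured batch size, which is what CUDA graph replay requires: storage addresses must stay fixed while only the view length changes. A minimal sketch of the pattern (sizes illustrative):

import torch

max_bs = 160  # illustrative capture budget
device = "cuda" if torch.cuda.is_available() else "cpu"
kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32, device=device)

# A wrapper captured for batch size bs sees a length-(bs + 1) view of the
# same storage, so graph replay never touches newly allocated memory.
bs = 8
view = kv_indptr[: bs + 1]
assert view.data_ptr() == kv_indptr.data_ptr()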
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-            self._get_wrapper_idx(layer)
-        ]
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )

-        logits_soft_cap = layer.logit_cap
+            logits_soft_cap = layer.logit_cap

-        if not self.forward_metadata.use_ragged:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
-                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
-                logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
-            )
-        else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
                 q.view(-1, layer.tp_q_head_num, layer.head_dim),
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
                 logits_soft_cap=logits_soft_cap,
             )

-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
+                    causal=not layer.is_cross_attention,
                     sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
+                    sm_scale=layer.scaling,
+                    logits_soft_cap=logits_soft_cap,
                 )

-                o, _ = merge_state(o1, s1, o2, s2)
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )

-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+                    o, _ = merge_state(o1, s1, o2, s2)
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )

-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
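In the MLA decode branch, layer.head_dim spans the compressed-KV dimension plus the rotary dimension while layer.v_head_dim is the compressed dimension alone, so slicing at v_head_dim splits q and the cached k into their non-rotary and RoPE parts. A shape-only sketch, assuming DeepSeek-V3 sizes (kv_lora_rank 512, rope dim 64; the random tensors are stand-ins for the real buffers):

import torch

head_dim, v_head_dim = 576, 512  # assumed: 512 compressed + 64 rotary dims
q = torch.randn(8, 16, head_dim)  # 8 decode tokens, 16 query heads
k_buffer = torch.randn(1024, 1, head_dim)  # compressed-KV cache, 1 kv head

q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]
ckv, k_rope = k_buffer[:, :, :v_head_dim], k_buffer[:, :, v_head_dim:]
assert q_rope.shape[-1] == k_rope.shape[-1] == 64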
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )


 class FlashInferIndicesUpdaterPrefill:
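The constants in the MLA plan call line up with DeepSeek-V3's attention shapes: 512 is the compressed-KV (ckv) head dim, 64 is the rotary (kpe) head dim, and the softmax scale still comes from the pre-absorption query-key head dim of 128 + 64 = 192. A quick check of that arithmetic (config values assumed from DeepSeek-V3, not stated in the diff):

import math

qk_nope_head_dim, qk_rope_head_dim = 128, 64  # assumed DeepSeek-V3 config values
kv_lora_rank = 512                            # matches the hard-coded 512 above

sm_scale = 1.0 / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)
assert math.isclose(sm_scale, 1.0 / math.sqrt(192))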
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-            wrapper_ragged.begin_forward(
-                qo_indptr,
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
+                1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
             )

-        # cached part
-        wrapper_paged.begin_forward(
-            qo_indptr,
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            q_data_type=self.q_data_type,
-            custom_mask=custom_mask,
-            non_blocking=True,
-        )
-

 class FlashInferMultiStepDraftBackend:
     """
@@ -947,7 +1071,7 @@ class FlashInferMultiStepDraftBackend:
             triton.next_power_of_2(bs),
         )

-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
         window_left,
         logits_soft_cap,
         head_dim,
+        head_dim,
         empty_q_data,
         empty_kv_cache,
         stream.cuda_stream,