sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/check_env.py +1 -0
- sglang/global_config.py +2 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/attention/flashinfer_backend.py +265 -147
- sglang/srt/layers/attention/triton_backend.py +358 -72
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/linear.py +12 -5
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +51 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
- sglang/srt/layers/quantization/fp8_kernel.py +123 -17
- sglang/srt/layers/quantization/fp8_utils.py +33 -4
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +16 -3
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/tokenizer_manager.py +6 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +16 -1
- sglang/srt/model_executor/model_runner.py +12 -2
- sglang/srt/models/deepseek_v2.py +17 -7
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +29 -8
- sglang/srt/utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
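The headline change in this release is opt-in FlashInfer MLA support for DeepSeek V3 style models, threaded through `flashinfer_backend.py`, `server_args.py`, `schedule_batch.py`, and `deepseek_v2.py`, and read as `global_server_args_dict["enable_flashinfer_mla"]` in the diff below. A minimal usage sketch, assuming the new `enable_flashinfer_mla` field is exposed through `ServerArgs` and forwarded by the offline `sglang.Engine` entrypoint (the exact flag surface lives in the `server_args.py` change, which this excerpt does not show):

```python
# Hedged sketch, not taken from the diff: assumes sglang 0.4.3's Engine
# accepts ServerArgs fields (including the new enable_flashinfer_mla) as kwargs.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V3",  # an MLA model (DeepseekV3ForCausalLM)
        enable_flashinfer_mla=True,  # decode through BatchMLAPagedAttentionWrapper
        trust_remote_code=True,
    )
    print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
    llm.shutdown()
```

As the guard in `flashinfer_backend.py` shows, the new wrappers are strictly opt-in: without the flag (or on non-DeepseekV3 architectures), decoding keeps the existing `BatchDecodeWithPagedKVCacheWrapper` path. The reconstructed hunks for `flashinfer_backend.py` follow.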
--- a/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/sglang/srt/layers/attention/flashinfer_backend.py
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
 Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
 """

+import math
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

@@ -35,7 +37,7 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.
+    from flashinfer.mla import BatchMLAPagedAttentionWrapper


 class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):

 @dataclass
 class DecodeMetadata:
-    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
+    decode_wrappers: List[
+        Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+    ]


 @dataclass
@@ -70,6 +74,8 @@ class FlashInferAttnBackend(AttentionBackend):
     ):
         super().__init__()

+        self.is_multimodal = model_runner.model_config.is_multimodal
+
         # Parse constants
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -101,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
         if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

+        self.enable_flashinfer_mla = False
+        if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            if global_server_args_dict["enable_flashinfer_mla"]:
+                self.enable_flashinfer_mla = True
+                global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -118,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
                 )
                 for _ in range(self.num_wrappers)
             ]
+            if self.enable_flashinfer_mla:
+                self.qo_indptr = [
+                    torch.zeros(
+                        (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+                    )
+                    for _ in range(self.num_wrappers)
+                ]
         else:
             assert self.num_wrappers == 1
             self.kv_indptr = [kv_indptr_buf]
@@ -130,12 +149,8 @@ class FlashInferAttnBackend(AttentionBackend):
             for _ in range(self.num_wrappers)
         ]

-
-
-        self.prefill_wrapper_ragged = (
-            BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
-            if self.num_wrappers == 1
-            else None
+        self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.workspace_buffer, "NHD"
         )

         # Two wrappers: one for sliding window attention and one for full attention.
@@ -155,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_wrappers_verify.append(
                 BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
             )
-            self.decode_wrappers.append(
-                BatchDecodeWithPagedKVCacheWrapper(
-                    self.workspace_buffer,
-                    "NHD",
-                    use_tensor_cores=self.decode_use_tensor_cores,
+            if self.enable_flashinfer_mla:
+                self.decode_wrappers.append(
+                    BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
+                )
+            else:
+                self.decode_wrappers.append(
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        use_tensor_cores=self.decode_use_tensor_cores,
+                    )
                 )
-            )

         # Create indices updater
         if not skip_prefill:
@@ -217,13 +237,12 @@ class FlashInferAttnBackend(AttentionBackend):
         else:
             prefix_lens = forward_batch.extend_prefix_lens

-
-        if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
-            use_ragged = True
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
-        else:
+        if self.is_multimodal:
             use_ragged = False
             extend_no_prefix = False
+        else:
+            use_ragged = True
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

         self.indices_updater_prefill.update(
             forward_batch.req_pool_indices,
@@ -277,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
         if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
-                decode_wrappers.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.workspace_buffer,
-                        "NHD",
-                        use_cuda_graph=True,
-                        use_tensor_cores=self.decode_use_tensor_cores,
-                        paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
-                        paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                        paged_kv_last_page_len_buffer=self.kv_last_page_len[
-                            :num_tokens
-                        ],
+                if self.enable_flashinfer_mla:
+                    decode_wrappers.append(
+                        BatchMLAPagedAttentionWrapper(
+                            self.workspace_buffer,
+                            use_cuda_graph=True,
+                            qo_indptr=self.qo_indptr[i][: num_tokens + 1],
+                            kv_indptr=self.kv_indptr[i][: num_tokens + 1],
+                            kv_indices=self.cuda_graph_kv_indices[i],
+                            kv_len_arr=self.kv_last_page_len[:num_tokens],
+                            backend="fa2",
+                        )
+                    )
+                else:
+                    decode_wrappers.append(
+                        BatchDecodeWithPagedKVCacheWrapper(
+                            self.workspace_buffer,
+                            "NHD",
+                            use_cuda_graph=True,
+                            use_tensor_cores=self.decode_use_tensor_cores,
+                            paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
+                            paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                            paged_kv_last_page_len_buffer=self.kv_last_page_len[
+                                :num_tokens
+                            ],
+                        )
                     )
-                )
             seq_lens_sum = seq_lens.sum().item()
             self.indices_updater_decode.update(
                 req_pool_indices,
@@ -378,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
-            self._get_wrapper_idx(layer)
-        ]
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
+        if global_config.enable_flashinfer_mla:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )

-        logits_soft_cap = layer.logit_cap
+            logits_soft_cap = layer.logit_cap

-        if not self.forward_metadata.use_ragged:
-            if k is not None:
-                assert v is not None
-                if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(
-                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                    )
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                causal=not layer.is_cross_attention,
+            o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
+                q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
+                causal=True,
                 sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
-                k_scale=layer.k_scale,
-                v_scale=layer.v_scale,
             )
+
+            o = o1
+
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         else:
-            o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
-                q.view(-1, layer.tp_q_head_num, layer.head_dim),
-                k.view(-1, layer.tp_k_head_num, layer.head_dim),
-                v.view(-1, layer.tp_v_head_num, layer.head_dim),
-                causal=True,
-                sm_scale=layer.scaling,
-                logits_soft_cap=logits_soft_cap,
+            prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
+                self._get_wrapper_idx(layer)
+            ]
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
             )

-            if self.forward_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
+            logits_soft_cap = layer.logit_cap
+
+            if not self.forward_metadata.use_ragged:
+                if k is not None:
+                    assert v is not None
+                    if save_kv_cache:
+                        forward_batch.token_to_kv_pool.set_kv_buffer(
+                            layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                        )
+
+                o = prefill_wrapper_paged.forward(
                     q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                     forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-                    causal=False,
+                    causal=not layer.is_cross_attention,
+                    sm_scale=layer.scaling,
+                    window_left=layer.sliding_window_size,
+                    logits_soft_cap=logits_soft_cap,
+                    k_scale=layer.k_scale,
+                    v_scale=layer.v_scale,
+                )
+            else:
+                o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
+                    q.view(-1, layer.tp_q_head_num, layer.head_dim),
+                    k.view(-1, layer.tp_k_head_num, layer.head_dim),
+                    v.view(-1, layer.tp_v_head_num, layer.head_dim),
+                    causal=True,
                     sm_scale=layer.scaling,
-                    logits_soft_cap=layer.logit_cap,
+                    logits_soft_cap=logits_soft_cap,
                 )

-                o, _ = merge_state(o1, s1, o2, s2)
+                if self.forward_metadata.extend_no_prefix:
+                    o = o1
+                else:
+                    o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                        forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                        causal=False,
+                        sm_scale=layer.scaling,
+                        logits_soft_cap=layer.logit_cap,
+                    )

-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
+                    o, _ = merge_state(o1, s1, o2, s2)

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
@@ -455,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )

-        if k is not None:
-            assert v is not None
-            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
+        if self.enable_flashinfer_mla:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+            o = decode_wrapper.run(
+                reshaped_q[:, :, : layer.v_head_dim],
+                reshaped_q[:, :, layer.v_head_dim :],
+                reshaped_k[:, :, : layer.v_head_dim],
+                reshaped_k[:, :, layer.v_head_dim :],
+            )

-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
-            sm_scale=layer.scaling,
-            logits_soft_cap=layer.logit_cap,
-            k_scale=layer.k_scale,
-            v_scale=layer.v_scale,
-        )
+            return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if k is not None:
+                assert v is not None
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+            o = decode_wrapper.forward(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
+                sm_scale=layer.scaling,
+                logits_soft_cap=layer.logit_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
+            )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def _get_wrapper_idx(self, layer: RadixAttention):
         if self.num_wrappers == 1:
@@ -519,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -531,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
+        decode_wrappers: List[
+            Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
+        ],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[SpecInfo],
     ):
@@ -612,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:

     def call_begin_forward(
         self,
-        wrapper: BatchDecodeWithPagedKVCacheWrapper,
+        wrapper: Union[
+            BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
+        ],
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
         paged_kernel_lens_sum: int,
@@ -640,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
             kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
             bs = kv_indptr.shape[0] - 1

-        wrapper.begin_forward(
-            kv_indptr,
-            kv_indices,
-            self.kv_last_page_len[:bs],
-            self.num_qo_heads,
-            self.num_kv_heads,
-            self.head_dim,
-            1,
-            data_type=self.data_type,
-            q_data_type=self.q_data_type,
-            non_blocking=True,
-        )
+        if global_config.enable_flashinfer_mla:
+            sm_scale = 1.0 / math.sqrt(192)
+            q_indptr = torch.arange(0, bs + 1).to(0).int()
+            kv_lens = paged_kernel_lens.to(torch.int32)
+            wrapper.plan(
+                q_indptr,
+                kv_indptr,
+                kv_indices,
+                kv_lens,
+                self.num_qo_heads,
+                512,
+                64,
+                1,
+                False,
+                sm_scale,
+                self.data_type,
+                self.data_type,
+            )
+        else:
+            wrapper.begin_forward(
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+                1,
+                data_type=self.data_type,
+                q_data_type=self.q_data_type,
+                non_blocking=True,
+            )


 class FlashInferIndicesUpdaterPrefill:
@@ -860,31 +969,42 @@ class FlashInferIndicesUpdaterPrefill:

         # extend part
         if use_ragged:
-            wrapper_ragged.end_forward()
-            wrapper_ragged.begin_forward(
-                qo_indptr,
+            if global_config.enable_flashinfer_mla:
+                wrapper_ragged.begin_forward(
+                    qo_indptr=qo_indptr,
+                    kv_indptr=qo_indptr,
+                    num_qo_heads=self.num_qo_heads,
+                    num_kv_heads=self.num_kv_heads,
+                    head_dim_qk=192,
+                    head_dim_vo=128,
+                    q_data_type=self.q_data_type,
+                )
+            else:
+                wrapper_ragged.begin_forward(
+                    qo_indptr,
+                    qo_indptr,
+                    self.num_qo_heads,
+                    self.num_kv_heads,
+                    self.head_dim,
+                    q_data_type=self.q_data_type,
+                )
+
+        if not global_config.enable_flashinfer_mla:
+            # cached part
+            wrapper_paged.begin_forward(
                 qo_indptr,
+                kv_indptr,
+                kv_indices,
+                self.kv_last_page_len[:bs],
                 self.num_qo_heads,
                 self.num_kv_heads,
                 self.head_dim,
+                1,
                 q_data_type=self.q_data_type,
+                custom_mask=custom_mask,
+                non_blocking=True,
             )

-            # cached part
-            wrapper_paged.end_forward()
-            wrapper_paged.begin_forward(
-                qo_indptr,
-                kv_indptr,
-                kv_indices,
-                self.kv_last_page_len[:bs],
-                self.num_qo_heads,
-                self.num_kv_heads,
-                self.head_dim,
-                1,
-                q_data_type=self.q_data_type,
-                custom_mask=custom_mask,
-            )
-

 class FlashInferMultiStepDraftBackend:
     """
@@ -924,38 +1044,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]

-    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
+
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-            self.cuda_graph_kv_indices,
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-            self.kv_indptr_stride,
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
-        for i in range(self.speculative_num_steps - 1):
+
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)

     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
@@ -965,7 +1097,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +1105,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1126,7 @@ class FlashInferMultiStepDraftBackend:
         ][0]
         decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1140,7 @@ class FlashInferMultiStepDraftBackend:
             spec_info=forward_batch.spec_info,
         )

-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)


 @triton.jit
@@ -1070,21 +1201,6 @@ def should_use_tensor_core(
     if env_override is not None:
         return env_override.lower() == "true"

-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
-        else:
-            return False
-    except (ImportError, AttributeError):
-        pass
-
     # Calculate GQA group size
     gqa_group_size = num_attention_heads // num_kv_heads

@@ -1114,6 +1230,7 @@ def fast_decode_plan(
     sm_scale: Optional[float] = None,
     rope_scale: Optional[float] = None,
     rope_theta: Optional[float] = None,
+    **kwargs,
 ) -> None:
     """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
     batch_size = len(last_page_len)
@@ -1170,6 +1287,7 @@ def fast_decode_plan(
             window_left,
             logits_soft_cap,
             head_dim,
+            head_dim,
             empty_q_data,
             empty_kv_cache,
             stream.cuda_stream,