sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashattention_backend.py

@@ -22,29 +22,55 @@ if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

-from
+from sgl_kernel.flash_attn import flash_attn_with_kvcache


 @dataclass
 class FlashAttentionMetadata:
-    """Metadata
+    """Metadata to be init once in the model forward pass,
+    each layer's forward pass can reuse the metadata."""

+    # Cumulative sequence lengths for query
     cu_seqlens_q: torch.Tensor = None
+    # Cumulative sequence lengths for key
     cu_seqlens_k: torch.Tensor = None
+    # Maximum sequence length for query
     max_seq_len_q: int = 0
+    # Maximum sequence length for key
     max_seq_len_k: int = 0
+    # Window size (typically used by Gemma)
     window_size: tuple = (-1, -1)
+    # Page table, the index of KV Cache Tables/Blocks
     page_table: torch.Tensor = None
+    # Sequence lengths for the forward batch
     cache_seqlens_int32: torch.Tensor = None


 class FlashAttentionBackend(AttentionBackend):
-    """FlashAttention backend implementation.
+    """FlashAttention backend implementation.
+
+    Note about the init:
+    - If no spec decoding
+        - FlashAttentionBackend will be init once when the server starts.
+    - If spec decoding
+        - FlashAttentionBackend will be init once for the target worker
+        - FlashAttentionMultiStepBackend will be once for the draft worker
+            - It will spawn num_steps FlashAttentionBackend for the draft worker
+
+    Note about CUDA Graph:
+    - We only support CUDA Graph for Decode (Normal Decode and Draft Decode) and Target Verify.
+    - We don't support CUDA Graph for Extend and Draft Extend.
+    - When server init, init_cuda_graph_state will be called first and then init_cuda_graph_capture will be called.
+    - For each forward batch, init_replay_cuda_graph will be called first and then replay the graph.
+    """

     def __init__(
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()

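Note on the new metadata fields above: cu_seqlens_q and cu_seqlens_k are the standard varlen-attention inputs, i.e. exclusive prefix sums of per-request token counts with a leading zero. A minimal, self-contained sketch of the construction this patch uses repeatedly (plain PyTorch, illustrative values only, not taken from the package):

    import torch

    # Hypothetical per-request KV lengths for a batch of 3 requests.
    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

    # pad(cumsum(...)) prepends a zero, giving [0, 5, 8, 15] for cu_seqlens_k.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
    )

    # In single-token decode every request contributes one query token,
    # so cu_seqlens_q is simply 0..batch_size.
    cu_seqlens_q = torch.arange(0, len(seq_lens) + 1, dtype=torch.int32)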
@@ -53,56 +79,121 @@ class FlashAttentionBackend(AttentionBackend):
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"

-        # Initialize metadata
         self.forward_metadata: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
+        self.target_verify_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.page_size = model_runner.page_size
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+
+        # TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
+        assert (
+            topk <= 1
+        ), "topk must be 1 (if spec decoding) or 0 (if no spec decoding) for FlashAttentionBackend"
+
+        self.topk = 1
+        self.step_id = step_id
+        self.speculative_num_steps = speculative_num_steps

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
-        # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-
-        # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
-        # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_batch.forward_mode.is_decode():
+            # Skip Prefill or Draft Decode
+            # Note: Draft Decode will be ran on the Draft Worker
+            if forward_batch.spec_info is not None:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = seq_lens_with_decode.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+
+                for idx, single_seq_len in enumerate(seq_lens_with_decode):
+                    real_bsz_start_idx = idx
+                    real_bsz_end_idx = idx + 1
+                    metadata.page_table[
+                        real_bsz_start_idx:real_bsz_end_idx,
+                        (single_seq_len - (self.step_id + 1)) : single_seq_len,
+                    ] = cache_loc[
+                        real_bsz_start_idx:real_bsz_end_idx, : (self.step_id + 1)
+                    ]
+            else:  # Normal Decode without Spec Decoding
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode.is_target_verify():
+            # Note: Target Verify will be ran on the Target Worker
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + draft_token_num
+            ).to(torch.int32)
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
             )
-
-        if forward_batch.forward_mode == ForwardMode.DECODE:
-            # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
-                0,
+                0,
+                batch_size * draft_token_num + 1,
+                draft_token_num,
+                dtype=torch.int32,
+                device=device,
             )
-
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
+
+        elif forward_batch.forward_mode.is_extend_or_draft_extend():
+            # Normal or Draft Extend (Both of them will be ran on the Target Worker)
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
             # Precompute cumulative sequence lengths
-            if
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
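In the target-verify branch added above, every request carries exactly draft_token_num query tokens, so cu_seqlens_q can be built with a strided arange instead of a cumsum. A small numeric sketch (illustrative only, hypothetical sizes):

    import torch

    bs, draft_token_num = 3, 4  # hypothetical batch of 3, 4 draft tokens each
    cu_seqlens_q = torch.arange(
        0, bs * draft_token_num + 1, draft_token_num, dtype=torch.int32
    )
    # tensor([ 0,  4,  8, 12]) -- the same result as pad(cumsum([4, 4, 4])),
    # since each request contributes a fixed number of query tokens.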
@@ -111,6 +202,15 @@ class FlashAttentionBackend(AttentionBackend):
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
+
+        # Precompute strided indices
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
         self.forward_metadata = metadata

     def forward_extend(
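The strided-index block above turns the token-level page_table (one KV-cache slot index per token) into a page-level table when page_size > 1: it keeps every page_size-th column and divides by page_size to obtain page ids. A self-contained sketch of that transformation, assuming a page's tokens occupy contiguous slots (toy values, not from the package):

    import torch

    page_size = 4
    # Token-level table for one request: KV slot index of each of 8 tokens.
    page_table = torch.tensor([[8, 9, 10, 11, 20, 21, 22, 23]])

    strided_indices = torch.arange(0, page_table.shape[1], page_size)
    page_level_table = page_table[:, strided_indices] // page_size
    # tensor([[2, 5]]) -- pages 2 and 5 back this request's 8 tokens.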
@@ -122,7 +222,6 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -157,7 +256,7 @@ class FlashAttentionBackend(AttentionBackend):

         page_table = metadata.page_table

-        #
+        # Use Flash Attention for prefill
         if not self.use_mla:
             # Do multi-head attention
             kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
@@ -263,7 +362,6 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-
         page_table = metadata.page_table

         if not self.use_mla:
@@ -281,8 +379,6 @@ class FlashAttentionBackend(AttentionBackend):

         # Pre-reshape query tensor
         q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
-        # Run attention with precomputed values
         o = flash_attn_with_kvcache(
             q=q_reshaped,
             k_cache=key_cache,
@@ -334,7 +430,6 @@ class FlashAttentionBackend(AttentionBackend):
             k_descale=layer.k_scale,
             v_descale=layer.v_scale,
         )
-
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def init_cuda_graph_state(self, max_bs: int):
@@ -346,7 +441,6 @@ class FlashAttentionBackend(AttentionBackend):
        This creates fixed-size tensors that will be reused during CUDA graph replay
        to avoid memory allocations.
        """
-        # Initialize fixed size tensors for decode operations
        self.decode_cuda_graph_metadata = {
            # Page table for token mapping (batch_size, max_context_len)
            "page_table": torch.zeros(
@@ -355,6 +449,39 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=self.device,
             ),
+            "page_table_draft_decode": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+        }
+
+        self.target_verify_metadata = {
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 128, dtype=torch.int32, device=self.device
+            ),
+            "max_seqlen_q": 0,
             "strided_indices": torch.arange(
                 0, self.max_context_len, self.page_size, device=self.device
             ),
@@ -372,27 +499,89 @@ class FlashAttentionBackend(AttentionBackend):
     ):
         """Initialize forward metadata for capturing CUDA graph."""
         metadata = FlashAttentionMetadata()
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-        batch_size = len(seq_lens)
         device = seq_lens.device
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode():
+            if spec_info is not None:
+                # Draft Decode
+                metadata.cu_seqlens_q = torch.arange(
+                    0, bs + 1, dtype=torch.int32, device=device
+                )
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = seq_lens.max().item() + (self.step_id + 1)
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][req_pool_indices, :]
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                batch_size = len(seq_lens)
+                device = seq_lens.device
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    req_pool_indices, :
+                ]
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+            self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
            )
-
-
-
+
+            metadata.max_seq_len_q = draft_token_num
+            metadata.max_seq_len_k = seq_lens.max().item() + draft_token_num
+
+            metadata.cu_seqlens_q = self.target_verify_metadata["cu_seqlens_q"][
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            ]
+            cu_k = self.target_verify_metadata["cu_seqlens_k"][: (bs + 1)]
+            cu_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            metadata.cu_seqlens_k = cu_k
+            metadata.page_table = self.target_verify_metadata["page_table"][
+                req_pool_indices, :
+            ]
+
+            self.target_verify_metadata[bs] = metadata
+
         self.forward_metadata = metadata

     def init_forward_metadata_replay_cuda_graph(
@@ -405,30 +594,159 @@ class FlashAttentionBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: torch.Tensor = None,
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-
+        device = seq_lens.device
+        seq_lens = seq_lens[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        if forward_mode.is_decode():
+            metadata = self.decode_cuda_graph_metadata[bs]
+
+            if spec_info is not None:
+                # Draft Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + (self.step_id + 1)
+
+                metadata.cache_seqlens_int32.copy_(
+                    (seq_lens + (self.step_id + 1)).to(torch.int32)
+                )

-
-        max_len = seq_lens_cpu[:bs].max().item()
-        metadata.max_seq_len_k = max_len
+                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (self.step_id + 1)

-
-
-
-
-
-
+                metadata.cu_seqlens_k.copy_(
+                    torch.nn.functional.pad(
+                        torch.cumsum(
+                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                        ),
+                        (1, 0),
+                    )
+                )
+
+                page_table = self.req_to_token[
+                    req_pool_indices, : metadata.max_seq_len_k
+                ]
+
+                metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
+            else:
+                # Normal Decode
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+                page_indices = self.req_to_token[
+                    :,
+                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+                ]
+                page_indices = page_indices[req_pool_indices] // self.page_size
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                metadata.page_table[:, max_seq_pages:].fill_(0)
+
+        elif forward_mode.is_target_verify():
+            metadata = self.target_verify_metadata[bs]
+            draft_token_num = spec_info.draft_token_num
+
+            metadata.cu_seqlens_q.copy_(
+                torch.arange(
+                    0,
+                    bs * draft_token_num + 1,
+                    draft_token_num,
+                    dtype=torch.int32,
+                    device=device,
+                )
+            )
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + draft_token_num).to(torch.int32)
+            )
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item() + draft_token_num
+            metadata.cu_seqlens_k.copy_(
+                torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+            )
+            page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
+            metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

-        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
-        page_indices = self.req_to_token[
-            :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
-        ]
-        page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
-        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
-        metadata.page_table[:, max_seq_pages:].fill_(0)
         self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+                out_cache_loc=forward_batch.out_cache_loc,
+            )
sglang/srt/layers/attention/flashinfer_backend.py

@@ -14,7 +14,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, List, Optional, Union

 import torch
-import triton

 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -22,7 +21,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -932,6 +931,7 @@ class FlashInferMultiStepDraftBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        self.page_size = model_runner.page_size

         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
@@ -985,9 +985,9 @@ class FlashInferMultiStepDraftBackend:
             self.pool_len,
             kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
-
-
-
+            next_power_of_2(num_seqs),
+            next_power_of_2(self.speculative_num_steps),
+            next_power_of_2(bs),
         )

         assert forward_batch.spec_info is not None
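The three new next_power_of_2 arguments above are presumably rounded up because the Triton kernel they are passed to uses them as constexpr block sizes, which generally must be powers of two. The helper itself is not shown in this diff; a typical implementation would look like the sketch below (an assumption, not necessarily sglang's exact code):

    def next_power_of_2(n: int) -> int:
        # Smallest power of two >= n (returns 1 for n <= 1).
        return 1 if n <= 1 else 1 << (n - 1).bit_length()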
@@ -1018,8 +1018,6 @@ class FlashInferMultiStepDraftBackend:
         )

         def call_fn(i, forward_batch):
-            assert forward_batch.spec_info is not None
-            assert isinstance(forward_batch.spec_info, EagleDraftInput)
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
             )
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.device = model_runner.device
         self.skip_prefill = skip_prefill

-        global_config.enable_flashinfer_mla = True
-
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -797,7 +795,7 @@ class FlashInferMLAMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=forward_batch.
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )

         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
sglang/srt/layers/attention/flashmla_backend.py

@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                 )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
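triton.cdiv here is plain ceiling division: the longest sequence length (now read from seq_lens_cpu) divided by PAGE_SIZE, rounded up, gives the padded number of KV blocks per request. Equivalent sketch:

    def cdiv(a: int, b: int) -> int:
        # Same result as triton.cdiv(a, b): ceiling of a / b.
        return (a + b - 1) // b

    # e.g. cdiv(130, 64) == 3 blocks for a 130-token sequence with 64-token pages.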