sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/cutlass_mla_backend.py

```diff
@@ -11,8 +11,6 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton
 
-from sglang.global_config import global_config
-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -22,7 +20,6 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpecInfo
 
 _is_cuda = is_cuda()
@@ -108,7 +105,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             PAGE_SIZE,
         )
         workspace_size = cutlass_mla_get_workspace_size(
-            max_seqlen_pad * PAGE_SIZE, bs
+            max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
         )
         workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
@@ -125,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
     def init_cuda_graph_state(
         self,
         max_bs: int,
+        max_num_tokens: int,
         block_kv_indices: Optional[torch.Tensor] = None,
     ):
         if block_kv_indices is None:
@@ -138,7 +136,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
            cuda_graph_kv_indices = block_kv_indices
 
         workspace_size = cutlass_mla_get_workspace_size(
-            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
         )
         self.cuda_graph_mla_workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
@@ -233,29 +231,55 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc
 
         if k is not None:
             assert v is not None
             if save_kv_cache:
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
 
+        q_nope = q_nope.to(self.q_data_type)
+        q_rope = q_rope.to(self.q_data_type)
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         o = cutlass_mla_decode(
+            q_nope=q_nope,
+            q_pe=q_rope,
             kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
             workspace=self.forward_metadata.workspace,
+            sm_scale=layer.scaling,
+            num_kv_splits=1,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
```
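The updated `forward_decode` path above either receives the rotary slice of the query separately (`q_rope`) or carves it out of a packed query tensor before calling the MLA decode kernel. A minimal standalone sketch of that slicing step, using made-up sizes in place of the layer attributes (`tp_q_head_num`, `head_dim`, `v_head_dim`) referenced in the diff:

```python
import torch

# Hypothetical MLA-style dimensions, for illustration only.
num_tokens, num_heads = 4, 16
v_head_dim, head_dim = 512, 576          # rope part = head_dim - v_head_dim = 64

q = torch.randn(num_tokens, num_heads * head_dim)

# Packed query: slice the "nope" and "rope" parts out of the last dimension,
# mirroring the else-branch in the diff above.
reshaped_q = q.view(-1, num_heads, head_dim)
q_nope = reshaped_q[:, :, :v_head_dim]   # (tokens, heads, 512)
q_rope = reshaped_q[:, :, v_head_dim:]   # (tokens, heads, 64)

assert q_nope.shape == (num_tokens, num_heads, v_head_dim)
assert q_rope.shape == (num_tokens, num_heads, head_dim - v_head_dim)
```

In the diff, these two slices are what the updated `cutlass_mla_decode` call consumes as `q_nope` and `q_pe`.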
sglang/srt/layers/attention/flashattention_backend.py

```diff
@@ -11,7 +11,6 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -394,7 +393,6 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
             )
             metadata_expand.max_seq_len_q = 1
-            metadata_expand.max_seq_len_k = self.speculative_step_id + 1
             metadata_expand.cu_seqlens_q = torch.arange(
                 0,
                 metadata_expand.cache_seqlens_int32.numel() + 1,
@@ -408,9 +406,10 @@ class FlashAttentionBackend(AttentionBackend):
                 dtype=torch.int32,
                 device=device,
             )
+            # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
             cache_loc = forward_batch.out_cache_loc.view(
-                self.speculative_num_steps
-            )
+                -1, self.speculative_num_steps
+            )
             metadata_expand.page_table = (
                 cache_loc[:, :decode_length].contiguous().to(torch.int32)
             )
@@ -550,9 +549,6 @@ class FlashAttentionBackend(AttentionBackend):
                 ),
                 (1, 0),
             )
-            metadata_expand.max_seq_len_k = (
-                metadata_expand.cache_seqlens_int32.max().item()
-            )
             self.forward_metadata_spec_decode_expand = metadata_expand
         elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
@@ -1124,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
@@ -1421,9 +1417,6 @@ class FlashAttentionBackend(AttentionBackend):
                 ]
             )
             metadata_expand.max_seq_len_q = 1
-            metadata_expand.max_seq_len_k = (
-                self.speculative_step_id + 1
-            )  # , do this in replay
             metadata_expand.cu_seqlens_q = (
                 self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
                     : bs * self.topk + 1
@@ -1469,7 +1462,7 @@ class FlashAttentionBackend(AttentionBackend):
                 "cache_seqlens"
             ][:bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
 
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
@@ -1536,7 +1529,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens
+            metadata.cache_seqlens_int32.copy_(seq_lens)
 
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1593,32 @@ class FlashAttentionBackend(AttentionBackend):
         if spec_info is not None:
             # Draft Decode
             if self.topk <= 1:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # When topk = 1, we use the normal decode metadata
-                metadata.
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    )
-                )
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][
-                        :max_seq_pages
-                    ],
-                ]
 
+                normal_decode_set_medadata(
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.speculative_step_id + 1,
+                    self.page_size,
+                )
+
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.draft_decode_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1650,11 +1637,10 @@ class FlashAttentionBackend(AttentionBackend):
                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand = self.draft_decode_metadata_topk_expand[bs]
                 decode_length = self.speculative_step_id + 1
-                ).T.contiguous()
+                # shape: [bs, num_steps, topk] -> [bs x topk, num_steps]
+                cache_loc = out_cache_loc.view(-1, self.speculative_num_steps)
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                    cache_loc[:, :decode_length]
+                    cache_loc[:, :decode_length]
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
@@ -1665,12 +1651,15 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = max_len
 
             normal_decode_set_medadata(
-                metadata,
+                metadata.cache_seqlens_int32,
+                metadata.cu_seqlens_k,
+                metadata.page_table,
                 self.req_to_token,
                 req_pool_indices,
                 self.decode_cuda_graph_metadata["strided_indices"],
                 max_seq_pages,
                 seq_lens,
+                0,
                 self.page_size,
             )
 
@@ -1679,7 +1668,7 @@ class FlashAttentionBackend(AttentionBackend):
         if self.topk <= 1:
             metadata = self.target_verify_metadata[bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
 
             metadata.max_seq_len_k = (
@@ -1701,7 +1690,7 @@ class FlashAttentionBackend(AttentionBackend):
             # When topk > 1, we need two specific target verify metadata, and then merge states
             # 1. The first half of metadata for prefix tokens
             metadata = self.target_verify_metadata_topk_normal[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             # metadata.cu_seqlens_q already set in capture
@@ -1715,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):
 
             # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
             metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
             # metadata_expand.max_seq_len_q = 1, already set in capture
             # metadata_expand.cu_seqlens_q already set in capture
             offsets = torch.arange(
                 self.speculative_num_draft_tokens, device=device
             ).unsqueeze(
                 0
             )  # shape: (1, self.speculative_num_draft_tokens)
+
             cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
             cum_len = torch.nn.functional.pad(
                 torch.cumsum(
@@ -1739,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
             ).view(1, -1)
             # avoid extracting padded seq indices which will be out of boundary
             mask_extraction_indices[
-                :,
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
             ].fill_(0)
             mask = spec_info.custom_mask[mask_extraction_indices].view(
                 -1, self.speculative_num_draft_tokens
             )  # (bsz * draft_num, draft_num)
+
             col_indices = offsets.expand(
                 mask.shape[0], self.speculative_num_draft_tokens
             )
             keys = torch.where(
-                mask,
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
             )
             _, sort_order = torch.sort(keys, dim=1)
@@ -1758,12 +1751,11 @@ class FlashAttentionBackend(AttentionBackend):
                 .gather(1, cols)
                 .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
             )  # (bsz, draft_num)
+
             metadata_expand.page_table.copy_(
                 non_masked_page_table.gather(1, sort_order)
             )
-            metadata_expand.cache_seqlens_int32.copy_(
-                mask.sum(dim=1).to(torch.int32)
-            )
+            metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
             metadata_expand.cu_seqlens_k[1:].copy_(
                 torch.cumsum(
                     metadata_expand.cache_seqlens_int32,
```
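The target-verify branch above packs the page-table entries selected by the custom mask to the front of each row before computing per-row lengths. A toy, self-contained demonstration of the `torch.where` + `torch.sort` trick (illustrative values only, not the backend's actual tensors):

```python
import torch

num_draft = 4
# One row per (request, draft position); True marks entries kept by the custom mask.
mask = torch.tensor([[True, False, True, False],
                     [True, True, False, True]])
page_table = torch.tensor([[10, 11, 12, 13],
                           [20, 21, 22, 23]])

col_indices = torch.arange(num_draft).expand_as(mask)
# Kept entries keep their column index; masked-out entries are shifted past the
# end, so an ascending sort moves kept entries to the front in original order.
keys = torch.where(mask, col_indices, col_indices + num_draft)
_, sort_order = torch.sort(keys, dim=1)

compacted = page_table.gather(1, sort_order)
lens = mask.sum(dim=1)                      # effective per-row lengths
print(compacted)   # tensor([[10, 12, 11, 13], [20, 21, 23, 22]])
print(lens)        # tensor([2, 3])
```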
sglang/srt/layers/attention/flashattention_backend.py (continued)

```diff
@@ -1771,19 +1763,21 @@ class FlashAttentionBackend(AttentionBackend):
                     dtype=torch.int32,
                 )
             )
-                metadata_expand.cache_seqlens_int32.max().item()
-            )
+
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens
+            metadata.cache_seqlens_int32.copy_(seq_lens)
 
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
@@ -1795,8 +1789,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -1824,7 +1817,7 @@ class FlashAttentionBackend(AttentionBackend):
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return
+        return 1
 
     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
@@ -2016,9 +2009,9 @@ class FlashAttentionMultiStepBackend:
         for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         for i in range(self.speculative_num_steps):
-            self.attn_backends[i].init_cuda_graph_state(max_bs)
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -2045,6 +2038,8 @@ class FlashAttentionMultiStepBackend:
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
 
         for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
@@ -2058,21 +2053,25 @@ class FlashAttentionMultiStepBackend:
             )
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
```
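The refactored `normal_decode_set_medadata` above now receives the individual metadata tensors instead of the whole metadata object, and both the draft-decode and normal-decode paths call it with a `seq_len_delta`. A simplified, self-contained sketch of the quantities it fills in; `build_decode_metadata` below is a hypothetical stand-in that returns fresh tensors rather than writing into preallocated CUDA-graph buffers:

```python
import torch

def build_decode_metadata(
    seq_lens: torch.Tensor,        # (bs,) current sequence lengths
    seq_len_delta: int,            # extra tokens (e.g. speculative_step_id + 1), 0 for plain decode
    req_to_token: torch.Tensor,    # (num_reqs, max_tokens) token-slot table
    req_pool_indices: torch.Tensor,
    page_size: int,
):
    """Simplified sketch of the metadata the refactored helper computes."""
    cache_seqlens = (seq_lens + seq_len_delta).to(torch.int32)

    # cu_seqlens_k: exclusive prefix sum over the per-request KV lengths.
    cu_seqlens_k = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    cu_seqlens_k[1:] = torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)

    # Page table: take every page_size-th token slot and convert it to a page id.
    max_seq_pages = (int(cache_seqlens.max()) + page_size - 1) // page_size
    strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
    page_indices = req_to_token[req_pool_indices[:, None], strided_indices[None, :max_seq_pages]]
    page_table = page_indices // page_size

    return cache_seqlens, cu_seqlens_k, page_table

# Toy usage with a 2-request batch and page_size=2.
req_to_token = torch.arange(32).view(2, 16)
seq_lens = torch.tensor([5, 3])
print(build_decode_metadata(seq_lens, 0, req_to_token, torch.tensor([0, 1]), page_size=2))
```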
sglang/srt/layers/attention/flashinfer_backend.py

```diff
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len,),
                 dtype=torch.int32,
                 device="cuda",
             )
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
         if not self.skip_prefill:
             self.cuda_graph_custom_mask = torch.zeros(
-                (
+                (max_num_tokens * self.max_context_len),
                 dtype=torch.uint8,
                 device="cuda",
             )
@@ -440,7 +443,7 @@ class FlashInferAttnBackend(AttentionBackend):
             raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1
 
     def forward_extend(
         self,
@@ -1049,14 +1052,13 @@ class FlashInferMultiStepDraftBackend:
                 kv_indices_buffer,
                 self.kv_indptr,
                 forward_batch.positions,
-                num_seqs,
-                self.topk,
                 self.pool_len,
                 kv_indices_buffer.shape[1],
                 self.kv_indptr.shape[1],
                 next_power_of_2(num_seqs),
                 next_power_of_2(self.speculative_num_steps),
                 next_power_of_2(bs),
+                self.page_size,
             )
 
         assert forward_batch.spec_info is not None
@@ -1097,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -1106,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
 
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
```
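A recurring change across these backends is that `init_cuda_graph_state` now takes a `max_num_tokens` argument, so buffers that grow with the number of tokens captured in a graph are sized from token capacity rather than from `max_bs` alone. A rough sketch of the idea with a hypothetical toy backend (not the actual classes above, and kept on CPU so it runs anywhere):

```python
import torch

class ToyAttnBackend:
    """Illustrative only: preallocate CUDA-graph buffers from token capacity."""

    def __init__(self, max_context_len: int):
        self.max_context_len = max_context_len

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        # With speculative decoding, one captured graph can hold several tokens
        # per request, so token capacity (not batch size) bounds these buffers.
        self.cuda_graph_kv_indices = torch.zeros(
            max_num_tokens * self.max_context_len, dtype=torch.int32
        )
        self.cuda_graph_seq_lens = torch.ones(max_bs, dtype=torch.int32)

backend = ToyAttnBackend(max_context_len=8)
backend.init_cuda_graph_state(max_bs=2, max_num_tokens=6)
print(backend.cuda_graph_kv_indices.shape)  # torch.Size([48])
```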
sglang/srt/layers/attention/flashinfer_mla_backend.py

```diff
@@ -15,7 +15,6 @@ from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
-import triton
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
     import logging
@@ -33,7 +32,7 @@ from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(
-        self,
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         if kv_indices_buf is None:
             cuda_graph_kv_indices = torch.zeros(
@@ -365,7 +367,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return
+        return 1
 
     def forward_extend(
         self,
@@ -756,7 +758,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         if topk > 1:
             raise ValueError(
+                "Currently Flashinfer MLA only supports topk=1 for speculative decoding"
             )
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
@@ -790,6 +792,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+        self.page_size = model_runner.server_args.page_size
 
     def common_template(
         self,
@@ -810,14 +813,13 @@ class FlashInferMLAMultiStepDraftBackend:
                 kv_indices_buffer,
                 self.kv_indptr,
                 forward_batch.positions,
-                num_seqs,
-                self.topk,
                 self.pool_len,
                 kv_indices_buffer.shape[1],
                 self.kv_indptr.shape[1],
+                next_power_of_2(num_seqs),
+                next_power_of_2(self.speculative_num_steps),
+                next_power_of_2(bs),
+                self.page_size,
             )
 
         assert forward_batch.spec_info is not None
@@ -853,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         self.common_template(forward_batch, kv_indices, call_fn)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
@@ -862,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
 
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -920,19 +922,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale
 
-    self.
-        raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")
```
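Several launch sites above pass `next_power_of_2(...)` values into the draft-KV-indices Triton kernel instead of raw sizes; Triton block sizes and compile-time bounds generally must be powers of two, so the dynamic sizes are rounded up first. A small sketch of what such a helper typically computes (assumed behavior for illustration, not copied from `sglang.srt.utils`):

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (sketch of the utility referenced above)."""
    return 1 << (n - 1).bit_length() if n > 0 else 1

# The rounded values serve as safe upper bounds for the kernel's loop extents.
assert [next_power_of_2(n) for n in (1, 3, 8, 9)] == [1, 4, 8, 16]
```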