sglang 0.4.3.post4__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +124 -665
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +78 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +9 -4
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -11
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/RECORD +124 -79
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_mla_backend.py

@@ -11,9 +11,10 @@ More details can be found in https://docs.flashinfer.ai/api/mla.html
 
 from dataclasses import dataclass
 from functools import partial
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
+import triton
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -23,6 +24,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 from sglang.srt.utils import is_flashinfer_available
 
 if TYPE_CHECKING:
@@ -58,12 +60,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
     def __init__(
         self,
         model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
         super().__init__()
 
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
+        self.skip_prefill = skip_prefill
 
         global_config.enable_flashinfer_mla = True
 
@@ -78,35 +84,51 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.workspace_buffer = global_workspace_buffer
 
         max_bs = model_runner.req_to_token_pool.size
-        self.kv_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
 
-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
 
-        self.q_indptr_decode = torch.arange(
-            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
-        )
+        if q_indptr_decode_buf is None:
+            self.q_indptr_decode = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.q_indptr_decode = q_indptr_decode_buf
 
         self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
             self.workspace_buffer, "NHD"
         )
 
-        self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
-            self.workspace_buffer,
-            backend="auto",
-        )
+        if not self.skip_prefill:
+            self.prefill_wrapper_paged = BatchMLAPagedAttentionWrapper(
+                self.workspace_buffer,
+                backend="auto",
+            )
+
+            # FlashinferMLA backend uses mla wrapper for target verify
+            self.prefill_wrapper_verify = BatchMLAPagedAttentionWrapper(
+                self.workspace_buffer,
+                backend="auto",
+            )
 
         self.decode_wrapper = BatchMLAPagedAttentionWrapper(
             self.workspace_buffer, backend="auto"
         )
 
         # Create indices updater
-        self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
-            model_runner, self
-        )
+        if not skip_prefill:
+            self.indices_updater_prefill = FlashInferMLAIndicesUpdaterPrefill(
+                model_runner, self
+            )
+
         self.indices_updater_decode = FlashInferMLAIndicesUpdaterDecode(
             model_runner, self
         )
@@ -114,7 +136,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Other metadata
         self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
         self.decode_cuda_graph_metadata = {}
-        self.prefill_cuda_graph_metadata = {}
+        self.prefill_cuda_graph_metadata = {}  # For verify
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -126,6 +148,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 init_metadata_replay=False,
             )
             self.forward_metadata = DecodeMetadata(self.decode_wrapper)
+        elif forward_batch.forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=self.prefill_wrapper_paged,
+                use_ragged=False,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_paged, False)
+        elif forward_batch.forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=self.prefill_wrapper_verify,
+                use_ragged=False,
+                spec_info=forward_batch.spec_info,
+            )
+            self.forward_metadata = PrefillMetadata(self.prefill_wrapper_verify, False)
         else:
             prefix_lens = forward_batch.extend_prefix_lens
             extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
@@ -202,10 +246,33 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 seq_lens_sum,
                 decode_wrapper=decode_wrapper,
                 init_metadata_replay=False,
+                spec_info=spec_info,
             )
             self.decode_cuda_graph_metadata[bs] = decode_wrapper
             self.forward_metadata = DecodeMetadata(decode_wrapper)
             decode_wrapper.plan = partial(fast_mla_decode_plan, decode_wrapper)
+        elif forward_mode.is_target_verify():
+            verify_wrapper = BatchMLAPagedAttentionWrapper(
+                self.workspace_buffer,
+                use_cuda_graph=True,
+                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
+                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
+                kv_indices=self.cuda_graph_kv_indices,
+                kv_len_arr=self.cuda_graph_kv_lens[:bs],
+                backend="auto",
+            )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=verify_wrapper,
+                use_ragged=False,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = verify_wrapper
+            self.forward_metadata = PrefillMetadata(verify_wrapper, False)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
@@ -221,6 +288,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
             kv_len_arr_cpu = seq_lens_cpu[:bs]
             self.cuda_graph_kv_indptr_cpu[1 : bs + 1] = torch.cumsum(
                 kv_len_arr_cpu, dim=0
@@ -239,8 +307,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 seq_lens_sum,
                 decode_wrapper=self.decode_cuda_graph_metadata[bs],
                 init_metadata_replay=True,
+                spec_info=spec_info,
                 **self.fast_decode_kwargs,
             )
+        elif forward_mode.is_target_verify():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                spec_info=spec_info,
+            )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
@@ -254,7 +333,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
 
         cache_loc = forward_batch.out_cache_loc
@@ -297,7 +376,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = True,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -349,6 +428,7 @@ class FlashInferMLAIndicesUpdaterDecode:
         seq_lens_sum: int,
         decode_wrapper: BatchMLAPagedAttentionWrapper,
         init_metadata_replay: bool = False,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
         **fast_decode_kwargs,
     ):
         decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -360,6 +440,7 @@ class FlashInferMLAIndicesUpdaterDecode:
             self.q_indptr,
             self.kv_indptr,
             init_metadata_replay,
+            spec_info,
             **fast_decode_kwargs,
         )
 
@@ -372,30 +453,33 @@ class FlashInferMLAIndicesUpdaterDecode:
         q_indptr: torch.Tensor,
         kv_indptr: torch.Tensor,
         init_metadata_replay: bool = False,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
         **fast_decode_kwargs,
     ):
         bs = len(req_pool_indices)
         q_indptr = q_indptr[: bs + 1]
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = (
-            torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
-            if not init_metadata_replay
-            else fast_decode_kwargs["kv_indices"]
-        )
-
         kv_lens = paged_kernel_lens.to(torch.int32)
         sm_scale = self.scaling
+        if spec_info is None:
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = (
+                torch.empty(paged_kernel_lens_sum, dtype=torch.int32, device="cuda")
+                if not init_metadata_replay
+                else fast_decode_kwargs["kv_indices"]
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+        else:
+            kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
 
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
         if not init_metadata_replay:
             wrapper.plan(
                 q_indptr,
@@ -457,6 +541,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
         prefix_lens: torch.Tensor,
         prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
         use_ragged: bool,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
    ):
         if use_ragged:
             paged_kernel_lens = prefix_lens
@@ -476,6 +561,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
             self.kv_indptr,
             self.qo_indptr,
             use_ragged,
+            spec_info,
         )
 
     def call_begin_forward(
@@ -490,29 +576,46 @@ class FlashInferMLAIndicesUpdaterPrefill:
         kv_indptr: torch.Tensor,
         qo_indptr: torch.Tensor,
         use_ragged: bool,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
     ):
-        bs = len(
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = torch.empty(
-            paged_kernel_lens_sum,
-            dtype=torch.int32,
-            device=req_pool_indices.device,
-        )
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            paged_kernel_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.shape[1],
-        )
-
-        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-        qo_indptr = qo_indptr[: bs + 1]
+        bs = len(seq_lens)
         sm_scale = self.scaling
 
+        if spec_info is None:
+            assert len(seq_lens) == len(req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.shape[1],
+            )
+            qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+        else:
+            assert isinstance(spec_info, EagleDraftInput) or isinstance(
+                spec_info, EagleVerifyInput
+            )
+            # TODO: Support topk > 1 with custom mask
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    paged_kernel_lens_sum,
+                    self.req_to_token,
+                )
+            )
+
         if use_ragged:
             # ragged prefill
             wrapper_ragged.begin_forward(
@@ -543,6 +646,163 @@ class FlashInferMLAIndicesUpdaterPrefill:
         )
 
 
+class FlashInferMLAMultiStepDraftBackend:
+    """
+    Wrap multiple flashinfer mla attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        if topk > 1:
+            raise ValueError(
+                f"Currently Flashinfer MLA only supports topk=1 for speculative decoding"
+            )
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+
+        max_bs = model_runner.req_to_token_pool.size * self.topk
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.q_indptr_decode = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=model_runner.device
+        )
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashInferMLAAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                    q_indptr_decode_buf=self.q_indptr_decode,
+                )
+            )
+
+        self.max_context_len = self.attn_backends[0].max_context_len
+
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: torch.Tensor,
+        call_fn: Callable,
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        def call_fn(i, forward_batch):
+            assert forward_batch.spec_info is not None
+            assert isinstance(forward_batch.spec_info, EagleDraftInput)
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+
 def fast_mla_decode_plan(
     self,
     qo_indptr_cpu: torch.Tensor,
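The largest addition above is FlashInferMLAMultiStepDraftBackend, which runs one FlashInferMLAAttnBackend per speculative draft step and hands each step a pre-allocated row of a shared kv_indptr buffer through the new skip_prefill / kv_indptr_buf / q_indptr_decode_buf constructor arguments. Below is a minimal, self-contained sketch of that buffer-sharing idea only; it runs outside sglang, uses hypothetical sizes, and is not the library's implementation.

```python
# Sketch: one kv_indptr row per draft step, each handed to a per-step backend.
# All sizes and the seq_lens batch below are made up for illustration.
import torch

speculative_num_steps = 3  # hypothetical number of draft steps
max_bs = 4                 # hypothetical req_to_token_pool size
seq_lens = torch.tensor([5, 2, 7], dtype=torch.int32)  # hypothetical KV lengths
bs = seq_lens.numel()

# Shared buffers, shaped as in the diff above.
kv_indptr = torch.zeros((speculative_num_steps, max_bs + 1), dtype=torch.int32)
q_indptr_decode = torch.arange(0, max_bs + 1, dtype=torch.int32)

for step in range(speculative_num_steps):
    # Mirrors the decode indices updater: indptr[1 : bs + 1] = cumsum(KV lengths).
    kv_indptr[step, 1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
    # In sglang, kv_indptr[step] would be passed as kv_indptr_buf to
    # FlashInferMLAAttnBackend(model_runner, skip_prefill=True, ...).
    print(step, kv_indptr[step].tolist())
```

Because every per-step backend keeps a view into the shared tensor, the added common_template method can fill all steps' indices with a single generate_draft_decode_kv_indices kernel launch and then replay each step's metadata in place.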
sglang/srt/layers/attention/triton_backend.py

@@ -6,9 +6,7 @@ import torch
 import triton
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-from sglang.srt.layers.attention.
-    create_flashinfer_kv_indices_triton,
-)
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -27,7 +27,7 @@ import triton.language as tl
 
 from sglang.srt.utils import is_hip
 
-
+_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx
-    if
+    if _is_hip:
         BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
-    if
+    if _is_hip:
         num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
     Lv = v_buffer.shape[-1]
 
     # [TODO] work around shmem limit on MI3xx
-    if
+    if _is_hip and Lk >= 576:
         BLOCK = 16
 
     if Lk == 576:
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
 
     extra_kargs = {}
     num_stages = 2
-    if
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
     NUM_KV_SPLITS = num_kv_splits
 
     extra_kargs = {}
-    if
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
-
+_is_hip = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if
+    if _is_hip:
         BLOCK_M, BLOCK_N = (64, 64)
         num_warps = 4
 
@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
     num_stages = 1
 
     extra_kargs = {}
-    if
+    if _is_hip:
         extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
 if is_cuda_available:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
-
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -330,7 +330,7 @@ def extend_attention_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if
+    if _is_hip:
         BLOCK_M, BLOCK_N = (64, 64)
         num_warps = 4
 
@@ -364,7 +364,7 @@ def extend_attention_fwd(
     num_stages = 1
 
     extra_kargs = {}
-    if
+    if _is_hip:
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
 
     _fwd_kernel[grid](
@@ -403,7 +403,7 @@ def extend_attention_fwd(
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        STORE_TRANSPOSE=
+        STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
         **extra_kargs,
sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py

@@ -32,7 +32,7 @@ def is_hip():
     return triton.runtime.driver.active.get_current_target().backend == "hip"
 
 
-
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
         BLOCK = 32
 
     # # [TODO] work around shmem limit on MI3xx
-    # if
+    # if _is_hip and kv_lora_rank >= 576:
     #     BLOCK = 16
 
     qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
 
     extra_kargs = {}
     num_stages = 2
-    if
+    if _is_hip:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}