sglang 0.4.2.post3__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- sglang/check_env.py +1 -0
- sglang/global_config.py +2 -0
- sglang/srt/constrained/outlines_backend.py +4 -1
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/layers/attention/flashinfer_backend.py +265 -147
- sglang/srt/layers/attention/triton_backend.py +358 -72
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/linear.py +12 -5
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +51 -5
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
- sglang/srt/layers/quantization/fp8_kernel.py +123 -17
- sglang/srt/layers/quantization/fp8_utils.py +33 -4
- sglang/srt/lora/backend/__init__.py +25 -5
- sglang/srt/lora/backend/base_backend.py +31 -9
- sglang/srt/lora/backend/flashinfer_backend.py +41 -4
- sglang/srt/lora/backend/triton_backend.py +34 -4
- sglang/srt/lora/layers.py +293 -0
- sglang/srt/lora/lora.py +101 -326
- sglang/srt/lora/lora_manager.py +101 -269
- sglang/srt/lora/mem_pool.py +174 -0
- sglang/srt/lora/triton_ops/__init__.py +7 -1
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +170 -0
- sglang/srt/lora/triton_ops/qkv_lora_b.py +5 -5
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +2 -2
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +2 -2
- sglang/srt/lora/utils.py +141 -0
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +16 -3
- sglang/srt/managers/scheduler.py +29 -0
- sglang/srt/managers/tokenizer_manager.py +6 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
- sglang/srt/model_executor/cuda_graph_runner.py +16 -1
- sglang/srt/model_executor/model_runner.py +12 -2
- sglang/srt/models/deepseek_v2.py +17 -7
- sglang/srt/server_args.py +20 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +64 -21
- sglang/srt/speculative/eagle_worker.py +29 -8
- sglang/srt/utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/METADATA +6 -5
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/RECORD +88 -55
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.post3.dist-info → sglang-0.4.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Optional
 
 import torch
+import triton
 
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -18,7 +19,12 @@ if TYPE_CHECKING:
 
 
 class TritonAttnBackend(AttentionBackend):
-    def __init__(
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+    ):
         # Lazy import to avoid the initialization of cuda context
         from sglang.srt.layers.attention.triton_ops.decode_attention import (
             decode_attention_fwd,
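The constructor change is what enables the multi-step draft wrapper added at the end of this file: callers can now hand in a preallocated kv_indptr buffer and skip the prefill-only buffers. A minimal sketch of how the new arguments might be used (illustrative only; model_runner is assumed to be an already-initialized ModelRunner):

import torch

# Assumed to exist: an initialized ModelRunner from the surrounding server setup.
max_bs = model_runner.req_to_token_pool.size
shared_kv_indptr = torch.zeros(
    (max_bs + 1,), dtype=torch.int32, device=model_runner.device
)

# Default construction, as before: prefill buffers are allocated internally.
prefill_backend = TritonAttnBackend(model_runner)

# Draft-decode style construction: reuse a caller-owned kv_indptr buffer and
# skip qo_indptr / mask_indptr, which only the prefill path needs.
draft_backend = TritonAttnBackend(
    model_runner,
    skip_prefill=True,
    kv_indptr_buf=shared_kv_indptr,
)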
@@ -32,14 +38,29 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.skip_prefill = skip_prefill
+
         max_bs = model_runner.req_to_token_pool.size
-
-
-
+
+        if kv_indptr_buf is None:
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+        else:
+            self.kv_indptr = kv_indptr_buf
+
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-
-
-
+
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -50,7 +71,7 @@ class TritonAttnBackend(AttentionBackend):
 
         self.forward_metadata = None
 
-        self.
+        self.max_context_len = model_runner.model_config.context_len
 
         self.device = model_runner.device
 
@@ -59,11 +80,31 @@ class TritonAttnBackend(AttentionBackend):
 
         bs = forward_batch.batch_size
         kv_indptr = self.kv_indptr
-
-
-
+        spec_info = forward_batch.spec_info
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = torch.zeros(
+                    forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                )
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+                bs = kv_indptr.shape[0] - 1
+
+            attn_logits = torch.zeros(
                 (
-
+                    bs,
                     self.num_head,
                     self.num_kv_splits,
                     self.v_head_dim + 1,
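For reference, the decode branch builds a ragged indptr/indices view over the paged KV cache: kv_indptr is the running sum of sequence lengths and kv_indices gathers each request's token slots out of req_to_token. A rough pure-PyTorch equivalent of what create_flashinfer_kv_indices_triton produces, using toy values:

import torch

# Toy stand-ins: 3 requests, their KV lengths, and a small req_to_token table.
seq_lens = torch.tensor([4, 2, 3], dtype=torch.int32)
req_pool_indices = torch.tensor([1, 0, 2])
req_to_token = torch.arange(3 * 8, dtype=torch.int32).reshape(3, 8)

bs = seq_lens.numel()
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)            # tensor([0, 4, 6, 9])

kv_indices = torch.zeros(int(kv_indptr[-1]), dtype=torch.int32)
for i in range(bs):
    start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
    # Slots owned by request i, read from its row in the request-to-token pool.
    kv_indices[start:end] = req_to_token[req_pool_indices[i], : int(seq_lens[i])]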
@@ -72,12 +113,24 @@ class TritonAttnBackend(AttentionBackend):
                 device=self.device,
             )
 
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
             max_extend_len = None
-
+        elif forward_batch.forward_mode.is_target_verify():
+            bs = len(forward_batch.req_pool_indices)
+            qo_indptr = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            # Different with flashinfer kv_indptr and kv_indices construction
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.
-
+            kv_indices = torch.zeros(
+                kv_indptr[-1], dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -89,15 +142,32 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
 
-
-
-
+            custom_mask = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (
+                forward_batch.seq_lens + self.num_draft_tokens
+            )
+            mask_indptr = self.mask_indptr
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len[:bs], dim=0)
+            mask_indptr = mask_indptr[: bs + 1]
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        elif forward_batch.forward_mode.is_draft_extend():
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    self.req_to_token,
+                )
+            )
+            mask_indptr = None
+            max_extend_len = torch.max(spec_info.accept_length).item()
+            attn_logits = None
         else:
             kv_indptr[1 : bs + 1] = torch.cumsum(
                 forward_batch.extend_prefix_lens, dim=0
             )
             kv_indptr = kv_indptr[: bs + 1]
-            kv_indices = torch.
+            kv_indices = torch.zeros(
                 forward_batch.extend_prefix_lens.sum().item(),
                 dtype=torch.int32,
                 device=self.device,
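In the target-verify branch each request contributes num_draft_tokens query rows, and each row is masked against seq_len + num_draft_tokens keys, so the flattened per-request mask sizes are num_draft_tokens * (seq_lens + num_draft_tokens) and mask_indptr is their prefix sum. A small worked example with made-up numbers:

import torch

num_draft_tokens = 4
seq_lens = torch.tensor([10, 7], dtype=torch.int32)   # two verify requests

# Flat mask entries per request: 4 * (10 + 4) = 56 and 4 * (7 + 4) = 44.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)

mask_indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)    # tensor([0, 56, 100])
# custom_mask is then a flat uint8 buffer with mask_indptr[-1] == 100 entries.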
@@ -116,8 +186,7 @@ class TritonAttnBackend(AttentionBackend):
             qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
             qo_indptr = qo_indptr[: bs + 1]
             custom_mask = None
-
-
+            mask_indptr = None
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
@@ -128,25 +197,32 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-
+            mask_indptr,
         )
 
-    def init_cuda_graph_state(
-        self
-
-        self.
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
-        self.cuda_graph_attn_logits = torch.empty(
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
+        self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-
-
-
-
-
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
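The new kv_indices_buf argument mirrors kv_indptr_buf on the constructor: an owner can pre-allocate one flat index buffer per draft step and hand each backend its slice before CUDA graph capture, which is what the multi-step wrapper added below does. A hedged sketch of that pattern, assuming backends already holds constructed TritonAttnBackend instances:

import torch

# Assumed to exist: backends, a list of TritonAttnBackend built with skip_prefill=True.
max_bs, max_context_len = 8, 4096
shared_kv_indices = torch.zeros(
    (len(backends), max_bs * max_context_len), dtype=torch.int32, device="cuda"
)

for i, backend in enumerate(backends):
    # Each step's backend captures its CUDA graph state into its own slice.
    backend.init_cuda_graph_state(max_bs, kv_indices_buf=shared_kv_indices[i])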
@@ -159,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
 
         self.forward_metadata = (
-
-
+            attn_logits,
+            max_extend_len,
             kv_indptr,
             kv_indices,
-
-
-
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -197,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
@@ -244,8 +395,9 @@ class TritonAttnBackend(AttentionBackend):
             kv_indices,
             qo_indptr,
             custom_mask,
-
+            mask_indptr,
         ) = self.forward_metadata
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -257,7 +409,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr,
             kv_indices,
             custom_mask,
-
+            mask_indptr,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -303,3 +455,137 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
         )
         return o
+
+
+class TritonMultiStepDraftBackend:
+    """
+    Wrap multiple triton attention backends as one for multiple consecutive
+    draft decoding steps.
+    """
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
+
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (
+                self.speculative_num_steps,
+                max_bs + 1,
+            ),
+            dtype=torch.int32,
+            device=model_runner.device,
+        )
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                TritonAttnBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    kv_indptr_buf=self.kv_indptr[i],
+                )
+            )
+        self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
+        # Cached variables for generate_draft_decode_kv_indices
+        self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
+
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
+        num_seqs = forward_batch.batch_size
+        bs = self.topk * num_seqs
+        seq_lens_sum = forward_batch.seq_lens_sum
+
+        self.generate_draft_decode_kv_indices[
+            (self.speculative_num_steps, num_seqs, self.topk)
+        ](
+            forward_batch.req_pool_indices,
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.seq_lens,
+            kv_indices_buffer,
+            self.kv_indptr,
+            forward_batch.positions,
+            num_seqs,
+            self.topk,
+            self.pool_len,
+            kv_indices_buffer.shape[1],
+            self.kv_indptr.shape[1],
+            triton.next_power_of_2(num_seqs),
+            triton.next_power_of_2(self.speculative_num_steps),
+            triton.next_power_of_2(bs),
+        )
+
+        for i in range(self.speculative_num_steps):
+            forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
+                : seq_lens_sum * self.topk + bs * (i + 1)
+            ]
+            call_fn(i, forward_batch)
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
+        def call_fn(i, forward_batch):
+            forward_batch.spec_info.kv_indptr = (
+                forward_batch.spec_info.kv_indptr.clone()
+            )
+            forward_batch.spec_info.kv_indices = (
+                forward_batch.spec_info.kv_indices.clone()
+            )
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+        self.common_template(forward_batch, kv_indices, call_fn)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        self.cuda_graph_kv_indices = torch.zeros(
+            (self.speculative_num_steps, max_bs * self.max_context_len),
+            dtype=torch.int32,
+            device=self.device,
+        )
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(
+                max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
+            )
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+
+    def init_forward_metadata_replay_cuda_graph(self, forward_batch):
+        def call_fn(i, forward_batch):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                seq_lens_sum=-1,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
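TritonMultiStepDraftBackend builds one TritonAttnBackend per speculative step, all sharing the step-indexed kv_indptr buffer, and common_template regenerates the draft-decode KV indices once before dispatching to each step's backend. A hedged construction sketch; the model_runner, server_args, and forward_batch objects are assumed to come from an EAGLE speculative-decoding worker, and the attribute names on server_args are assumptions:

# Illustrative only: objects below come from the surrounding speculative worker.
draft_backend = TritonMultiStepDraftBackend(
    model_runner,
    topk=server_args.speculative_eagle_topk,                   # assumed ServerArgs field
    speculative_num_steps=server_args.speculative_num_steps,   # assumed ServerArgs field
)

# One-time setup when CUDA graphs are enabled.
draft_backend.init_cuda_graph_state(max_bs=8)

# Per draft batch, before running the speculative decode steps.
draft_backend.init_forward_metadata(forward_batch)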
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED
@@ -50,7 +50,7 @@ def _fwd_kernel(
     kv_indptr,
     kv_indices,
     mask_ptr,
-
+    mask_indptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -87,7 +87,7 @@ def _fwd_kernel(
     cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
 
     if USE_CUSTOM_MASK:
-        cur_seq_mask_start_idx = tl.load(
+        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -288,7 +288,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
-
+    mask_indptr,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -364,7 +364,7 @@ def extend_attention_fwd(
        kv_indptr,
        kv_indices,
        custom_mask,
-
+        mask_indptr,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),
sglang/srt/layers/linear.py
CHANGED
@@ -421,11 +421,18 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-
-
-
-
-
+
+        from sglang.srt.layers.parameter import _ColumnvLLMParameter
+
+        if isinstance(param, _ColumnvLLMParameter):
+            # FIXME: why would we need this special case?
+            param.load_column_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            param.load_column_parallel_weight(loaded_weight)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None