sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +6 -25
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +104 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +58 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +117 -131
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +57 -44
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -1
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/flashinfer_utils.py
@@ -0,0 +1,235 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    max_context_len: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
+
+
+class FlashinferUpdater:
+    def __init__(
+        self,
+        forward_mode,
+        model_runner,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        decode_wrapper=None,
+        use_ragged=False,
+    ):
+        self.forward_mode = forward_mode
+        self.model_runner = model_runner
+        self.req_pool_indices = req_pool_indices
+        self.seq_lens = seq_lens
+        self.prefix_lens = prefix_lens
+        self.use_ragged = use_ragged
+
+        self.num_qo_heads = (
+            model_runner.model_config.num_attention_heads // model_runner.tp_size
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            model_runner.tp_size
+        )
+        self.head_dim = model_runner.model_config.head_dim
+        self.batch_size = len(req_pool_indices)
+
+        self.decode_wrapper = (
+            decode_wrapper or self.model_runner.attn_backend.decode_wrapper
+        )
+        self.prefill_wrapper_ragged = (
+            self.model_runner.attn_backend.prefill_wrapper_ragged
+        )
+        self.prefill_wrapper_paged = (
+            self.model_runner.attn_backend.prefill_wrapper_paged
+        )
+
+        self.kv_last_page_len = torch.ones(
+            (self.batch_size,), dtype=torch.int32, device="cuda"
+        )
+
+    def _init_indices_no_sliding_window(self):
+        if self.use_ragged:
+            paged_kernel_lens = self.prefix_lens
+        else:
+            paged_kernel_lens = self.seq_lens
+
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            None,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _init_indices_sliding_window(self, wrapper_id):
+        if wrapper_id == 0:
+            # window attention use paged only
+            if self.forward_mode.is_decode():
+                paged_kernel_lens = torch.minimum(
+                    self.seq_lens,
+                    torch.tensor(self.model_runner.sliding_window_size + 1),
+                )
+            else:
+                paged_kernel_lens = torch.minimum(
+                    self.seq_lens,
+                    torch.tensor(self.model_runner.sliding_window_size)
+                    + self.seq_lens
+                    - self.prefix_lens,
+                )
+        else:
+            # full attention
+            paged_kernel_lens = self.seq_lens
+
+        kv_start_idx = self.seq_lens - paged_kernel_lens
+        self.kv_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+        self.kv_indices = torch.empty(
+            self.kv_indptr[-1], dtype=torch.int32, device="cuda"
+        )
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.model_runner.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            paged_kernel_lens,
+            self.kv_indptr,
+            kv_start_idx,
+            self.kv_indices,
+            self.model_runner.req_to_token_pool.req_to_token.size(1),
+        )
+
+    def _update_decode_indices(self, decode_wrapper):
+        decode_wrapper.end_forward()
+        decode_wrapper.begin_forward(
+            self.kv_indptr,
+            self.kv_indices,
+            self.kv_last_page_len,
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+            data_type=self.model_runner.kv_cache_dtype,
+            q_data_type=self.model_runner.dtype,
+        )
+
+    def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
+        # extend part
+        qo_indptr = torch.zeros(
+            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+        )
+        qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0)
+
+        if self.use_ragged:
+            ragged_wrapper.end_forward()
+            ragged_wrapper.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                self.num_qo_heads,
+                self.num_kv_heads,
+                self.head_dim,
+            )
+
+        # cached part
+        paged_wrapper.end_forward()
+        paged_wrapper.begin_forward(
+            qo_indptr,
+            self.kv_indptr,
+            self.kv_indices,
+            self.kv_last_page_len,
+            self.num_qo_heads,
+            self.num_kv_heads,
+            self.head_dim,
+            1,
+        )
+
+    def update_indices_no_sliding_window(self):
+        self._init_indices_no_sliding_window()
+
+        if self.forward_mode.is_decode():
+            self._update_decode_indices(self.decode_wrapper)
+        else:
+            self._update_extend_indices(
+                self.prefill_wrapper_ragged,
+                self.prefill_wrapper_paged,
+            )
+
+    def update_indices_sliding_window(self):
+        assert self.use_ragged is False
+
+        for wrapper_id in range(2):
+            self._init_indices_sliding_window(wrapper_id)
+            if self.forward_mode.is_decode():
+                self._update_decode_indices(self.decode_wrapper[wrapper_id])
+            else:
+                self._update_extend_indices(
+                    None,
+                    self.prefill_wrapper_paged[wrapper_id],
+                )
+
+
+def update_flashinfer_indices(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    decode_wrapper=None,
+    use_ragged=False,
+):
+    updater = FlashinferUpdater(
+        forward_mode,
+        model_runner,
+        req_pool_indices,
+        seq_lens,
+        prefix_lens,
+        decode_wrapper,
+        use_ragged,
+    )
+
+    if model_runner.sliding_window_size is None:
+        updater.update_indices_no_sliding_window()
+    else:
+        updater.update_indices_sliding_window()
sglang/srt/layers/logits_processor.py
@@ -37,7 +37,7 @@ class LogitsProcessorOutput:

     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of input tokens.
+    # The logprobs of input tokens. shape: [#token, vocab_size]
     input_token_logprobs: torch.Tensor

     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
@@ -49,25 +49,39 @@ class LogitsProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
+    top_logprobs_nums: Optional[List[int]]
+
     return_logprob: bool = False
+    return_top_logprob: bool = False

     extend_seq_lens: Optional[torch.Tensor] = None
-
-    top_logprobs_nums: Optional[List[int]] = None
+    extend_seq_lens_cpu: Optional[List[int]] = None

-
-
+    extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_logprob_pruned_lens_cpu: Optional[List[int]] = None

     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
+        return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+        if input_metadata.forward_mode.is_extend():
+            extend_logprob_pruned_lens_cpu = [
+                extend_len - start_len
+                for extend_len, start_len in zip(
+                    input_metadata.extend_seq_lens,
+                    input_metadata.extend_logprob_start_lens_cpu,
+                )
+            ]
+        else:
+            extend_logprob_pruned_lens_cpu = None
         return cls(
             forward_mode=input_metadata.forward_mode,
-            extend_seq_lens=input_metadata.extend_seq_lens,
-            extend_start_loc=input_metadata.extend_start_loc,
-            return_logprob=input_metadata.return_logprob,
             top_logprobs_nums=input_metadata.top_logprobs_nums,
+            return_logprob=input_metadata.return_logprob,
+            return_top_logprob=return_top_logprob,
+            extend_seq_lens=input_metadata.extend_seq_lens,
             extend_seq_lens_cpu=input_metadata.extend_seq_lens_cpu,
-
+            extend_logprob_start_lens_cpu=input_metadata.extend_logprob_start_lens_cpu,
+            extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
         )


@@ -82,57 +96,49 @@ class LogitsProcessor(nn.Module):
     def _get_normalized_prompt_logprobs(
         self,
         input_token_logprobs: torch.Tensor,
-        cum_start_len0: torch.Tensor,
-        cum_start_len1: torch.Tensor,
         logits_metadata: LogitsMetadata,
     ):
         logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
+        pruned_lens = torch.tensor(
+            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
+        )

-        start =
-
-
-
+        start = torch.zeros_like(pruned_lens)
+        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
+        end = torch.clamp(
+            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
+        )
         sum_logp = (
             logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
-        normalized_prompt_logprobs = sum_logp / (
-            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
-        )
-
+        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
         return normalized_prompt_logprobs

     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
-
+        max_k = max(logits_metadata.top_logprobs_nums)
+        ret = all_logprobs.topk(max_k, dim=1)
+        values = ret.values.tolist()
+        indices = ret.indices.tolist()
+
+        if logits_metadata.forward_mode.is_decode():
             output_top_logprobs = []
-            max_k = max(logits_metadata.top_logprobs_nums)
-            ret = all_logprobs.topk(max_k, dim=1)
-            values = ret.values.tolist()
-            indices = ret.indices.tolist()
             for i, k in enumerate(logits_metadata.top_logprobs_nums):
                 output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
             return None, output_top_logprobs
         else:
-            # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
-            pt = 0
-            extend_seq_lens_cpu = logits_metadata.extend_seq_lens_cpu

-
-
-
-
-
-
-                start_len = logits_metadata.logprob_start_lens_cpu[i]
-                pruned_len = extend_seq_len - start_len
-
-                if extend_seq_len == 0:
+            pt = 0
+            for k, pruned_len in zip(
+                logits_metadata.top_logprobs_nums,
+                logits_metadata.extend_logprob_pruned_lens_cpu,
+            ):
+                if pruned_len <= 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue

-                k = logits_metadata.top_logprobs_nums[i]
                 input_top_logprobs.append(
                     [
                         list(zip(values[pt + j][:k], indices[pt + j][:k]))
@@ -163,14 +169,11 @@ class LogitsProcessor(nn.Module):
         assert isinstance(logits_metadata, LogitsMetadata)

         # Get the last hidden states and last logits for the next token prediction
-        if logits_metadata.forward_mode == ForwardMode.DECODE:
+        if logits_metadata.forward_mode.is_decode():
             last_index = None
             last_hidden = hidden_states
         else:
-            last_index = (
-                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
-                - 1
-            )
+            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]

         last_logits = torch.matmul(last_hidden, weight.T)
@@ -194,21 +197,15 @@ class LogitsProcessor(nn.Module):
                 output_top_logprobs=None,
             )
         else:
-
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
+            last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

-
-                return_top_logprob = any(
-                    x > 0 for x in logits_metadata.top_logprobs_nums
-                )
-                if return_top_logprob:
+            if logits_metadata.forward_mode.is_decode():
+                if logits_metadata.return_top_logprob:
                     output_top_logprobs = self.get_top_logprobs(
                         last_logprobs, logits_metadata
                     )[1]
                 else:
                     output_top_logprobs = None
-
                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
@@ -218,22 +215,18 @@ class LogitsProcessor(nn.Module):
                     output_top_logprobs=output_top_logprobs,
                 )
             else:
+                # Slice the requested tokens to compute logprob
                 pt, states, pruned_input_ids = 0, [], []
-                for
-
+                for start_len, extend_len in zip(
+                    logits_metadata.extend_logprob_start_lens_cpu,
+                    logits_metadata.extend_seq_lens_cpu,
+                ):
                     states.append(hidden_states[pt + start_len : pt + extend_len])
                     pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                     pt += extend_len

+                # Compute the logits and logprobs for all required tokens
                 states = torch.cat(states, dim=0)
-                pruned_input_ids = torch.cat(pruned_input_ids, dim=0)
-
-                cum_start_len1 = torch.tensor(
-                    logits_metadata.logprob_start_lens_cpu, device="cuda"
-                ).cumsum(0)
-                cum_start_len0 = torch.zeros_like(cum_start_len1)
-                cum_start_len0[1:] = cum_start_len1[:-1]
-
                 all_logits = torch.matmul(states, weight.T)
                 if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
@@ -249,35 +242,29 @@ class LogitsProcessor(nn.Module):
                 all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

                 # Get the logprob of top-k tokens
-                return_top_logprob = any(
-                    x > 0 for x in logits_metadata.top_logprobs_nums
-                )
-                if return_top_logprob:
+                if logits_metadata.return_top_logprob:
                     input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                         all_logprobs, logits_metadata
                     )
                 else:
                     input_top_logprobs = output_top_logprobs = None

-
-
-                # Compute the logprobs and normalized logprobs for the prefill tokens.
-                # Note that we pad a zero at the end of each sequence for easy computation.
+                # Compute the normalized logprobs for the requested tokens.
+                # Note that we pad a zero at the end for easy batching.
                 input_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat(
+                    torch.cat(
+                        [
+                            torch.cat(pruned_input_ids)[1:],
+                            torch.tensor([0], device="cuda"),
+                        ]
+                    ),
                 ]
-
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     input_token_logprobs,
-                    cum_start_len0,
-                    cum_start_len1,
                     logits_metadata,
                 )

-                # Remove the last token logprob for the prefill tokens.
-                input_token_logprobs = input_token_logprobs[:-1]
-
                 return LogitsProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
sglang/srt/layers/radix_attention.py
@@ -15,20 +15,16 @@ limitations under the License.

 """Radix attention."""

-from typing import Optional
-
-import torch
-from flashinfer.cascade import merge_state
 from torch import nn

-from sglang.global_config import global_config
-from sglang.srt.layers.decode_attention import decode_attention_fwd
-from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.model_executor.model_runner import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class RadixAttention(nn.Module):
+    """
+    The attention layer implementation.
+    """
+
     def __init__(
         self,
         num_heads: int,
@@ -36,8 +32,8 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        sliding_window_size:
-        logit_cap:
+        sliding_window_size: int = -1,
+        logit_cap: float = 0.0,
         v_head_dim: int = -1,
     ):
         super().__init__()
@@ -49,160 +45,14 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.
-
-        if (
-            not global_server_args_dict.get("disable_flashinfer", False)
-            and self.qk_head_dim == self.v_head_dim
-        ):
-            self.extend_forward = self.extend_forward_flashinfer
-            self.decode_forward = self.decode_forward_flashinfer
-        else:
-            self.extend_forward = self.extend_forward_triton
-            self.decode_forward = self.decode_forward_triton
-
-        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
-
-    def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-
-        self.store_kv_cache(k, v, input_metadata)
-        extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            k.contiguous(),
-            v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_prefix_lens,
-            input_metadata.extend_start_loc,
-            input_metadata.extend_seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.triton_max_extend_len,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-
-        return o
-
-    def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        if self.qk_head_dim != self.v_head_dim:
-            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
-        else:
-            o = torch.empty_like(q)
-        self.store_kv_cache(k, v, input_metadata)
-
-        decode_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
-            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
-            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.v_head_dim),
-            input_metadata.req_to_token_pool.req_to_token,
-            input_metadata.req_pool_indices,
-            input_metadata.triton_start_loc,
-            input_metadata.seq_lens,
-            input_metadata.triton_max_seq_len,
-            input_metadata.total_num_tokens,
-            sm_scale=self.scaling,
-            logit_cap=self.logit_cap,
-        )
-
-        return o
-
-    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-        prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_paged = prefill_wrapper_paged[0]
-        else:
-            if isinstance(prefill_wrapper_paged, list):
-                prefill_wrapper_paged = prefill_wrapper_paged[1]
-
-        if not input_metadata.flashinfer_use_ragged:
-            if k is not None:
-                assert v is not None
-                self.store_kv_cache(k, v, input_metadata)
-
-            o = prefill_wrapper_paged.forward(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
-            )
-        else:
-            o1, s1 = (
-                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                    causal=True,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-            )
-
-            if input_metadata.extend_no_prefix:
-                o = o1
-            else:
-                o2, s2 = prefill_wrapper_paged.forward_return_lse(
-                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                    input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                    causal=False,
-                    sm_scale=self.scaling,
-                    logits_soft_cap=self.logit_cap,
-                )
-
-                o, _ = merge_state(o1, s1, o2, s2)
-
-            self.store_kv_cache(k, v, input_metadata)
-
-        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-            torch.cuda.synchronize()
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
-
-    def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        decode_wrapper = input_metadata.flashinfer_decode_wrapper
-        if self.sliding_window_size != -1:
-            decode_wrapper = decode_wrapper[0]
-        else:
-            if isinstance(decode_wrapper, list):
-                decode_wrapper = decode_wrapper[1]
-
-        if k is not None:
-            assert v is not None
-            self.store_kv_cache(k, v, input_metadata)
-
-        o = decode_wrapper.forward(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
-
-        return o.view(-1, self.tp_q_head_num * self.head_dim)
+        self.logit_cap = logit_cap
+        self.sliding_window_size = sliding_window_size or -1

     def forward(self, q, k, v, input_metadata: InputMetadata):
         if k is not None:
+            # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-
-            return self.extend_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.decode_forward(q, k, v, input_metadata)
-
-    def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        input_metadata.token_to_kv_pool.set_kv_buffer(
-            self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
-        )
+        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)