sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

```diff
@@ -1,11 +1,56 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List, Union
+
 import torch
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )

+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
+
+
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normlaized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
+@dataclasses.dataclass
+class LogitsMetadata:
+    forward_mode: ForwardMode
+    extend_seq_lens: torch.Tensor
+    extend_start_loc: torch.Tensor
+
+    # For logprobs
+    return_logprob: bool
+    top_logprobs_nums: List[int]
+
+    @classmethod
+    def from_input_metadata(cls, input_metadata: InputMetadata):
+        return cls(
+            forward_mode=input_metadata.forward_mode,
+            extend_seq_lens=input_metadata.extend_seq_lens,
+            extend_start_loc=input_metadata.extend_start_loc,
+            return_logprob=input_metadata.return_logprob,
+            top_logprobs_nums=input_metadata.top_logprobs_nums,
+        )
+

 class LogitsProcessor(nn.Module):
     def __init__(self, config):
@@ -13,78 +58,159 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()

-    def
-
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )

-
-
-
-
-
-
-
-
+        start = logits_metadata.extend_start_loc.clone()
+        end = start + logits_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+        # TODO: vectorize the code below
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = logits_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                if extend_seq_len == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
+                    continue
+                k = logits_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_len
+
+        return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(
+        self,
+        input_ids,
+        hidden_states,
+        weight,
+        logits_metadata: Union[LogitsMetadata, InputMetadata],
+    ):
+        if isinstance(logits_metadata, InputMetadata):
+            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
+        assert isinstance(logits_metadata, LogitsMetadata)
+
+        # Get the last hidden states and last logits for the next token prediction
+        if logits_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
+            last_index = (
+                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
+            last_hidden = hidden_states[last_index]

-
-
-
-
-
-
-
-
-        last_logits
-
-
-
-        return
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        if hasattr(self.config, "final_logit_softcapping"):
+            last_logits /= self.config.final_logit_softcapping
+            last_logits = torch.tanh(last_logits)
+            last_logits *= self.config.final_logit_softcapping
+
+        # Return only last_logits if logprob is not requested
+        if not logits_metadata.return_logprob:
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
-
-
-
-
-
-
-
-
-
-
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = all_logits.float()
+            del all_logits
+            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            # Get the logprob of top-k tokens
+            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
+            if return_top_logprob:
+                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                    all_logprobs, logits_metadata
+                )
+            else:
+                prefill_top_logprobs = decode_top_logprobs = None
+
+            if logits_metadata.forward_mode == ForwardMode.DECODE:
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
             else:
-                # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )

-
-
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
-                )
-                normalized_logprobs = sum_logp / (
-                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, logits_metadata
                 )

-
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
+                )


-
+def test():
     all_logprobs = torch.tensor(
         # s s s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -93,23 +219,26 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")

-
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)

     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=
-    end.clamp_(min=0, max=
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] +
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]

     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs",
+    print("token logprobs", token_logprobs)
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
```
sglang/srt/layers/radix_attention.py

```diff
@@ -1,46 +1,42 @@
+"""Radix attention."""
+
 import torch
-from
+from flashinfer.cascade import merge_state
+from torch import nn
+
+from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.
-from
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


 class RadixAttention(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        num_heads: int,
+        head_dim: int,
+        scaling: float,
+        num_kv_heads: int,
+        layer_id: int,
+        logit_cap: int = -1,
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.scaling = scaling
         self.layer_id = layer_id

-
-
-        if global_server_args_dict.get("enable_flashinfer", False):
-            self.prefill_forward = self.prefill_forward_flashinfer
-            self.extend_forward = self.prefill_forward_flashinfer
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
-            self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton

-
-        o = torch.empty_like(q)
-
-        context_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
-            k,
-            v,
-            o.view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.start_loc,
-            input_metadata.seq_lens,
-            input_metadata.max_seq_len,
-        )
-        self.store_kv_cache(k, v, input_metadata)
-
-        return o
+        self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -54,13 +50,15 @@ class RadixAttention(nn.Module):
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.
+            input_metadata.triton_prefix_lens,
             input_metadata.extend_start_loc,
             input_metadata.extend_seq_lens,
-            input_metadata.
-            input_metadata.
+            input_metadata.triton_max_seq_len,
+            input_metadata.triton_max_extend_len,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )

         return o
@@ -76,31 +74,54 @@ class RadixAttention(nn.Module):
             o.view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
-            input_metadata.
+            input_metadata.triton_start_loc,
             input_metadata.seq_lens,
-            input_metadata.
-            input_metadata.other_kv_index,
+            input_metadata.triton_max_seq_len,
             input_metadata.total_num_tokens,
+            sm_scale=self.scaling,
+            logit_cap=self.logit_cap,
         )

         return o

-    def
-
-
-        o = input_metadata.prefill_wrapper.forward(
+    def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            causal=True,
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
         )

+        if input_metadata.extend_no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                sm_scale=self.scaling,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        self.store_kv_cache(k, v, input_metadata)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)

-        o = input_metadata.
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
         )

         return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -109,25 +130,13 @@ class RadixAttention(nn.Module):
         k = k.view(-1, self.tp_k_head_num, self.head_dim)
         v = v.view(-1, self.tp_v_head_num, self.head_dim)

-        if input_metadata.forward_mode == ForwardMode.
-            return self.prefill_forward(q, k, v, input_metadata)
-        elif input_metadata.forward_mode == ForwardMode.EXTEND:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
         elif input_metadata.forward_mode == ForwardMode.DECODE:
             return self.decode_forward(q, k, v, input_metadata)

     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        key_buffer[input_metadata.out_cache_loc] = cache_k
         value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
-
-            key_buffer[input_metadata.out_cache_loc] = cache_k
-            value_buffer[input_metadata.out_cache_loc] = cache_v
-        elif input_metadata.out_cache_cont_start is not None:
-            key_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_k
-            value_buffer[
-                input_metadata.out_cache_cont_start : input_metadata.out_cache_cont_end
-            ] = cache_v
-        else:
-            raise RuntimeError()
+        value_buffer[input_metadata.out_cache_loc] = cache_v
```
sglang/srt/layers/token_attention.py

```diff
@@ -4,7 +4,8 @@
 import torch
 import triton
 import triton.language as tl
-
+
+from sglang.srt.managers.controller.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
@@ -15,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -34,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -76,6 +84,10 @@ def _fwd_kernel_stage1(
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
+
+        if logit_cap > 0:
+            att_value = logit_cap * tanh(att_value / logit_cap)
+
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

@@ -95,7 +107,6 @@ def _fwd_kernel_stage2(
     stride_obs,
     stride_oh,
     stride_req_to_token_b,
-    other_kv_index,  # To fix a NAN issue
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -126,7 +137,7 @@ def _fwd_kernel_stage2(
             + cur_batch_req_idx * stride_req_to_token_b
             + (start_n + offs_n),
             mask=(start_n + offs_n) < cur_batch_seq_len,
-            other=
+            other=0,
         )

         qk = tl.load(
@@ -164,13 +175,14 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    sm_scale,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
     assert Lk in {16, 32, 64, 128, 256}
-    sm_scale = 1.0 / (Lk**0.5)

     batch, head_num = B_req_idx.shape[0], q.shape[1]

@@ -222,6 +234,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -236,7 +249,6 @@ def _token_softmax_reducev_fwd(
     b_req_idx,
     b_start_loc,
     b_seq_len,
-    other_kv_index,
 ):
     BLOCK = 64
     batch, head = b_seq_len.shape[0], logics.shape[0]
@@ -263,7 +275,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
     )
     return

@@ -281,7 +292,6 @@ def _token_softmax_reducev_fwd(
         o.stride(0),
         o.stride(1),
         req_to_tokens.stride(0),
-        other_kv_index,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=v_buffer.shape[-1],
         BLOCK_N=BLOCK,
@@ -301,8 +311,9 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
-    other_kv_index,
     total_num_tokens,
+    sm_scale,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:
@@ -319,6 +330,8 @@ def token_attention_fwd(
     b_start_loc,
     b_seq_len,
     max_len_in_batch,
+        sm_scale,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
@@ -328,5 +341,4 @@ def token_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-        other_kv_index,
     )
```