sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -1,21 +1,24 @@
 import json
 from typing import List, Optional
 
-import numpy as np
-
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.choices import (
+    ChoicesDecision,
+    ChoicesSamplingMethod,
+    token_length_normalized,
+)
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
-        auth_token: Optional[str] = None,
         api_key: Optional[str] = None,
         verify: Optional[str] = None,
     ):
@@ -23,13 +26,11 @@ class RuntimeEndpoint(BaseBackend):
         self.support_concate_and_append = True
 
         self.base_url = base_url
-        self.auth_token = auth_token
         self.api_key = api_key
         self.verify = verify
 
         res = http_request(
             self.base_url + "/get_model_info",
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -46,7 +47,7 @@ class RuntimeEndpoint(BaseBackend):
     def flush_cache(self):
         res = http_request(
             self.base_url + "/flush_cache",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -54,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def get_server_args(self):
         res = http_request(
             self.base_url + "/get_server_args",
-            auth_token=self.auth_token,
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -67,7 +68,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -79,7 +79,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -91,7 +90,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -139,7 +137,6 @@ class RuntimeEndpoint(BaseBackend):
         res = http_request(
             self.base_url + "/generate",
             json=data,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -193,7 +190,6 @@ class RuntimeEndpoint(BaseBackend):
             self.base_url + "/generate",
             json=data,
             stream=True,
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
@@ -216,21 +212,14 @@ class RuntimeEndpoint(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
         assert temperature <= 1e-5
 
         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            auth_token=self.auth_token,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+        obj = self._generate_http_request(s, data)
+        prompt_len = obj["meta_info"]["prompt_tokens"]
 
         # Compute logprob
         data = {
@@ -239,40 +228,57 @@ class RuntimeEndpoint(BaseBackend):
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
         }
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            auth_token=self.auth_token,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        obj = res.json()
+        obj = self._generate_http_request(s, data)
+
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        decision = choices[np.argmax(normalized_prompt_logprobs)]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
 
-        return (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
+        # Compute unconditional logprobs if required
+        if choices_method.requires_unconditional_logprobs:
+            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
+            data = {
+                "input_ids": input_ids,
+                "sampling_params": {"max_new_tokens": 0},
+                "return_logprob": True,
+            }
+            obj = self._generate_http_request(s, data)
+            unconditional_token_logprobs = [
+                r["meta_info"]["input_token_logprobs"] for r in obj
+            ]
+        else:
+            unconditional_token_logprobs = None
+
+        return choices_method(
+            choices=choices,
+            normalized_prompt_logprobs=normalized_prompt_logprobs,
+            input_token_logprobs=input_token_logprobs,
+            output_token_logprobs=output_token_logprobs,
+            unconditional_token_logprobs=unconditional_token_logprobs,
         )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
             self.base_url + "/concate_and_append_request",
             json={"src_rids": src_rids, "dst_rid": dst_rid},
-            auth_token=self.auth_token,
             api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
 
+    def _generate_http_request(self, s: StreamExecutor, data):
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
     def _add_images(self, s: StreamExecutor, data):
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
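Taken together, these changes drop the deprecated `auth_token` argument, deduplicate the repeated `/generate` calls into `_generate_http_request`, and let `select` delegate ranking to a pluggable `ChoicesSamplingMethod`. A minimal usage sketch, assuming a server already running at `http://localhost:30000` and that `sgl.select` (whose `api.py` diff is not shown here) forwards a `choices_method` keyword down to `SglSelect`:

```python
import sglang as sgl
from sglang.lang.choices import greedy_token_selection

# Assumes an sglang server is already running at this URL.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def classify(s, question):
    s += "Question: " + question + "\nAnswer: "
    # choices_method is the new knob; token_length_normalized remains the default.
    s += sgl.select("answer", choices=["yes", "no"], choices_method=greedy_token_selection)

state = classify.run(question="Is the sky blue?")
print(state["answer"], state.get_meta_info("answer"))
```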
sglang/lang/choices.py
ADDED
@@ -0,0 +1,164 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+
+
+@dataclass
+class ChoicesDecision:
+    decision: str
+    meta_info: Optional[Dict[str, Any]] = None
+
+
+class ChoicesSamplingMethod(ABC):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return False
+
+    @abstractmethod
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision: ...
+
+
+class TokenLengthNormalized(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest token length normalized prompt logprob."""
+        best_choice = choices[np.argmax(normalized_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+
+token_length_normalized = TokenLengthNormalized()
+
+
+class GreedyTokenSelection(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option based on greedy logprob selection. For overlapping options
+        where one option is a subset of a longer option, extend the shorter option using
+        its average logprob for comparison against the longer option."""
+
+        num_options = len(choices)
+        max_tokens = max(len(option) for option in input_token_logprobs)
+        logprob_matrix = self._build_logprob_matrix(
+            input_token_logprobs, max_tokens, num_options
+        )
+        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
+
+        best_choice = choices[remaining[0]]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "greedy_logprob_matrix": logprob_matrix.tolist(),
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
+        logprob_matrix = np.zeros((num_options, max_tokens))
+        for i, option in enumerate(input_token_logprobs):
+            actual_logprobs = [token[0] for token in option]
+            avg_logprob = np.mean(actual_logprobs)
+            logprob_matrix[i, : len(option)] = actual_logprobs
+            if len(option) < max_tokens:
+                logprob_matrix[i, len(option) :] = avg_logprob
+        return logprob_matrix
+
+    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
+        remaining = np.arange(num_options)
+        for j in range(max_tokens):
+            max_logprob = np.max(logprob_matrix[remaining, j])
+            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
+            if len(remaining) == 1:
+                break
+        return remaining
+
+
+greedy_token_selection = GreedyTokenSelection()
+
+
+class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return True
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest average token logprob once normalized by
+        the unconditional token logprobs.
+
+        The first unconditional token logprob is assumed to be None. If so, it is
+        replaced with 0 for the purposes of normalization."""
+
+        if unconditional_token_logprobs is None:
+            raise ValueError(
+                "Unconditional token logprobs are required for this method."
+            )
+
+        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
+            input_token_logprobs, unconditional_token_logprobs
+        )
+
+        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "unconditional_token_logprobs": unconditional_token_logprobs,
+            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
+        normalized_unconditional_prompt_logprobs = []
+        for inputs, unconditionals in zip(
+            input_token_logprobs, unconditional_token_logprobs
+        ):
+            inputs_logprobs = np.array([token[0] for token in inputs])
+            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
+            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
+            normalized_unconditional_prompt_logprobs.append(
+                float(np.mean(inputs_logprobs - unconditionals_logprobs))
+            )
+        return normalized_unconditional_prompt_logprobs
+
+
+unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
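Because the sampling methods are plain callables, they can be exercised offline without a server. A toy example with fabricated `(logprob, token_id)` pairs, illustrating how the length-normalized and greedy strategies can disagree:

```python
import numpy as np

from sglang.lang.choices import greedy_token_selection, token_length_normalized

# Fabricated (logprob, token_id) pairs for two candidate continuations.
input_token_logprobs = [
    [(-0.1, 11), (-2.0, 12)],              # option "yes":   mean logprob -1.05
    [(-0.2, 11), (-0.3, 13), (-0.4, 14)],  # option "maybe": mean logprob -0.3
]
normalized = [float(np.mean([t[0] for t in opt])) for opt in input_token_logprobs]

common = dict(
    choices=["yes", "maybe"],
    normalized_prompt_logprobs=normalized,
    input_token_logprobs=input_token_logprobs,
    output_token_logprobs=[[], []],
)

print(token_length_normalized(**common).decision)  # "maybe" (best mean logprob)
print(greedy_token_selection(**common).decision)   # "yes" (wins on the first token)
```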
sglang/lang/interpreter.py
CHANGED
@@ -538,24 +538,17 @@ class StreamExecutor:
             self.stream_var_event[name].set()
 
     def _execute_select(self, expr: SglSelect):
-        (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
-        ) = self.backend.select(self, expr.choices, expr.temperature)
+        choices_decision = self.backend.select(
+            self, expr.choices, expr.temperature, expr.choices_method
+        )
         if expr.name is not None:
             name = expr.name
-            self.variables[name] = decision
-            self.meta_info[name] = {
-                "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "input_token_logprobs": input_token_logprobs,
-                "output_token_logprobs": output_token_logprobs,
-            }
+            self.variables[name] = choices_decision.decision
+            self.meta_info[name] = choices_decision.meta_info
             self.variable_event[name].set()
             if self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.text_ += decision
+        self.text_ += choices_decision.decision
 
     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
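The interpreter now treats `backend.select` as returning a single `ChoicesDecision` instead of the old 4-tuple, so the decision and its meta info travel together. A minimal sketch of the new contract for a hypothetical custom backend (`base_backend.py` also changes in this release, +3 -1, presumably to match this signature):

```python
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod

class FirstChoiceBackend(BaseBackend):
    """Hypothetical stub: always picks the first option, but honors the new
    return type so _execute_select can read .decision and .meta_info."""

    def select(self, s, choices, temperature, choices_method: ChoicesSamplingMethod):
        return ChoicesDecision(decision=choices[0], meta_info={"stub": True})
```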
sglang/lang/ir.py
CHANGED
@@ -6,6 +6,7 @@ import warnings
 from typing import List, Optional, Union
 
 from sglang.global_config import global_config
+from sglang.lang.choices import ChoicesSamplingMethod
 
 REGEX_INT = r"[-+]?[0-9]+"
 REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):
 
 
 class SglSelect(SglExpr):
-    def __init__(self, name: str, choices: List[str], temperature: float):
+
+    def __init__(
+        self,
+        name: str,
+        choices: List[str],
+        temperature: float,
+        choices_method: ChoicesSamplingMethod,
+    ):
         super().__init__()
         self.name = name
         self.choices = choices
         self.temperature = temperature
+        self.choices_method = choices_method
 
     def __repr__(self):
-        return f"Select({self.name}, choices={self.choices})"
+        return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"
 
 
 class SglFork(SglExpr):
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -19,7 +19,7 @@ import functools
 import json
 import os
 import warnings
-from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
+from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
             Literal["all"], AbstractSet[str]
         ] = set(),  # noqa: B006
         disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-    ) -> list[int]:
+    ) -> List[int]:
         if isinstance(allowed_special, set):
             allowed_special |= self._default_allowed_special
         return tiktoken.Encoding.encode(
sglang/srt/layers/extend_attention.py
CHANGED
@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
 
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)
 
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
 
@@ -110,6 +124,18 @@ def _fwd_kernel(
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(
 
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
 
@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )
 
-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv
 
     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
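The new `BLOCK_DPE` path supports attention heads whose query/key width includes a separate rotary part, apparently for DeepSeek-V2's MLA attention (deepseek_v2.py is the largest model change in this diff): a 576-wide qk head splits into a 512-dim compressed part plus a 64-dim rotary part, while value heads stay 512 wide. A small sketch mirroring the dispatch logic above, for reference (hypothetical helper, not part of the module):

```python
def pick_blocks(Lq: int, Lv: int):
    """Mirror of the kernel dispatch above: for 576-wide qk heads,
    split into a 512-dim model part and a 64-dim positional (rope) part."""
    assert Lq in {16, 32, 64, 128, 256, 576}
    assert Lv in {16, 32, 64, 128, 256, 512}
    if Lq == 576:
        block_dmodel, block_dpe = 512, 64  # MLA: nope + rope split
    else:
        block_dmodel, block_dpe = Lq, 0    # dense heads: no split
    return block_dmodel, block_dpe, Lv     # BLOCK_DV follows the value width

assert pick_blocks(576, 512) == (512, 64, 512)
assert pick_blocks(128, 128) == (128, 0, 128)
```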
sglang/srt/layers/logits_processor.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
sglang/srt/layers/radix_attention.py
CHANGED
@@ -22,11 +22,8 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.model_executor.model_runner import (
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import global_server_args_dict
 
 
 class RadixAttention(nn.Module):
@@ -38,16 +35,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
 
-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +60,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
         self.store_kv_cache(k, v, input_metadata)
         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +89,17 @@ class RadixAttention(nn.Module):
         return o
 
     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
-        o = torch.empty_like(q)
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
+            o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
 
         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +170,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
 
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)