sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff shows the changes between publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/lang/backend/runtime_endpoint.py
CHANGED
@@ -1,17 +1,21 @@
 import json
 from typing import List, Optional
 
-import numpy as np
-
 from sglang.global_config import global_config
 from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.choices import (
+    ChoicesDecision,
+    ChoicesSamplingMethod,
+    token_length_normalized,
+)
 from sglang.lang.interpreter import StreamExecutor
 from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
+
     def __init__(
         self,
         base_url: str,
@@ -43,7 +47,7 @@ class RuntimeEndpoint(BaseBackend):
     def flush_cache(self):
         res = http_request(
             self.base_url + "/flush_cache",
-
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -51,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def get_server_args(self):
         res = http_request(
             self.base_url + "/get_server_args",
-
+            api_key=self.api_key,
             verify=self.verify,
         )
         self._assert_success(res)
@@ -208,20 +212,14 @@ class RuntimeEndpoint(BaseBackend):
         s: StreamExecutor,
         choices: List[str],
         temperature: float,
-    ):
+        choices_method: ChoicesSamplingMethod,
+    ) -> ChoicesDecision:
         assert temperature <= 1e-5
 
         # Cache common prefix
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        prompt_len = res.json()["meta_info"]["prompt_tokens"]
+        obj = self._generate_http_request(s, data)
+        prompt_len = obj["meta_info"]["prompt_tokens"]
 
         # Compute logprob
         data = {
@@ -230,27 +228,35 @@ class RuntimeEndpoint(BaseBackend):
             "return_logprob": True,
             "logprob_start_len": max(prompt_len - 2, 0),
         }
-        self._add_images(s, data)
-        res = http_request(
-            self.base_url + "/generate",
-            json=data,
-            api_key=self.api_key,
-            verify=self.verify,
-        )
-        self._assert_success(res)
-        obj = res.json()
+        obj = self._generate_http_request(s, data)
+
         normalized_prompt_logprobs = [
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
-        decision = choices[np.argmax(normalized_prompt_logprobs)]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
 
-        return (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
+        # Compute unconditional logprobs if required
+        if choices_method.requires_unconditional_logprobs:
+            input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
+            data = {
+                "input_ids": input_ids,
+                "sampling_params": {"max_new_tokens": 0},
+                "return_logprob": True,
+            }
+            obj = self._generate_http_request(s, data)
+            unconditional_token_logprobs = [
+                r["meta_info"]["input_token_logprobs"] for r in obj
+            ]
+        else:
+            unconditional_token_logprobs = None
+
+        return choices_method(
+            choices=choices,
+            normalized_prompt_logprobs=normalized_prompt_logprobs,
+            input_token_logprobs=input_token_logprobs,
+            output_token_logprobs=output_token_logprobs,
+            unconditional_token_logprobs=unconditional_token_logprobs,
         )
 
     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
@@ -262,6 +268,17 @@ class RuntimeEndpoint(BaseBackend):
         )
         self._assert_success(res)
 
+    def _generate_http_request(self, s: StreamExecutor, data):
+        self._add_images(s, data)
+        res = http_request(
+            self.base_url + "/generate",
+            json=data,
+            api_key=self.api_key,
+            verify=self.verify,
+        )
+        self._assert_success(res)
+        return res.json()
+
     def _add_images(self, s: StreamExecutor, data):
         if s.images_:
             assert len(s.images_) == 1, "Only support one image."
sglang/lang/choices.py
ADDED
@@ -0,0 +1,164 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+
+
+@dataclass
+class ChoicesDecision:
+    decision: str
+    meta_info: Optional[Dict[str, Any]] = None
+
+
+class ChoicesSamplingMethod(ABC):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return False
+
+    @abstractmethod
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision: ...
+
+
+class TokenLengthNormalized(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest token length normalized prompt logprob."""
+        best_choice = choices[np.argmax(normalized_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+
+token_length_normalized = TokenLengthNormalized()
+
+
+class GreedyTokenSelection(ChoicesSamplingMethod):
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option based on greedy logprob selection. For overlapping options
+        where one option is a subset of a longer option, extend the shorter option using
+        its average logprob for comparison against the longer option."""
+
+        num_options = len(choices)
+        max_tokens = max(len(option) for option in input_token_logprobs)
+        logprob_matrix = self._build_logprob_matrix(
+            input_token_logprobs, max_tokens, num_options
+        )
+        remaining = self._greedy_selection(logprob_matrix, num_options, max_tokens)
+
+        best_choice = choices[remaining[0]]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "greedy_logprob_matrix": logprob_matrix.tolist(),
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _build_logprob_matrix(self, input_token_logprobs, max_tokens, num_options):
+        logprob_matrix = np.zeros((num_options, max_tokens))
+        for i, option in enumerate(input_token_logprobs):
+            actual_logprobs = [token[0] for token in option]
+            avg_logprob = np.mean(actual_logprobs)
+            logprob_matrix[i, : len(option)] = actual_logprobs
+            if len(option) < max_tokens:
+                logprob_matrix[i, len(option) :] = avg_logprob
+        return logprob_matrix
+
+    def _greedy_selection(self, logprob_matrix, num_options, max_tokens):
+        remaining = np.arange(num_options)
+        for j in range(max_tokens):
+            max_logprob = np.max(logprob_matrix[remaining, j])
+            remaining = remaining[logprob_matrix[remaining, j] == max_logprob]
+            if len(remaining) == 1:
+                break
+        return remaining
+
+
+greedy_token_selection = GreedyTokenSelection()
+
+
+class UnconditionalLikelihoodNormalized(ChoicesSamplingMethod):
+
+    @property
+    def requires_unconditional_logprobs(self) -> bool:
+        return True
+
+    def __call__(
+        self,
+        *,
+        choices: List[str],
+        normalized_prompt_logprobs: List[float],
+        input_token_logprobs: List[List[Any]],
+        output_token_logprobs: List[List[Any]],
+        unconditional_token_logprobs: Optional[List[List[Any]]] = None,
+    ) -> ChoicesDecision:
+        """Select the option with the highest average token logprob once normalized by
+        the unconditional token logprobs.
+
+        The first unconditional token logprob is assumed to be None. If so, it is
+        replaced with 0 for the purposes of normalization."""
+
+        if unconditional_token_logprobs is None:
+            raise ValueError(
+                "Unconditional token logprobs are required for this method."
+            )
+
+        normalized_unconditional_prompt_logprobs = self._normalize_logprobs(
+            input_token_logprobs, unconditional_token_logprobs
+        )
+
+        best_choice = choices[np.argmax(normalized_unconditional_prompt_logprobs)]
+        meta_info = {
+            "normalized_prompt_logprobs": normalized_prompt_logprobs,
+            "input_token_logprobs": input_token_logprobs,
+            "output_token_logprobs": output_token_logprobs,
+            "unconditional_token_logprobs": unconditional_token_logprobs,
+            "normalized_unconditional_prompt_logprobs": normalized_unconditional_prompt_logprobs,
+        }
+        return ChoicesDecision(decision=best_choice, meta_info=meta_info)
+
+    def _normalize_logprobs(self, input_token_logprobs, unconditional_token_logprobs):
+        normalized_unconditional_prompt_logprobs = []
+        for inputs, unconditionals in zip(
+            input_token_logprobs, unconditional_token_logprobs
+        ):
+            inputs_logprobs = np.array([token[0] for token in inputs])
+            unconditionals_logprobs = np.array([token[0] for token in unconditionals])
+            unconditionals_logprobs[0] = unconditionals_logprobs[0] or 0
+            normalized_unconditional_prompt_logprobs.append(
+                float(np.mean(inputs_logprobs - unconditionals_logprobs))
+            )
+        return normalized_unconditional_prompt_logprobs
+
+
+unconditional_likelihood_normalized = UnconditionalLikelihoodNormalized()
sglang/lang/compiler.py
CHANGED
@@ -125,7 +125,7 @@ class CompiledFunction:
     def run(
         self,
         *,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -155,7 +155,7 @@ class CompiledFunction:
         self,
         batch_kwargs,
         *,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
sglang/lang/interpreter.py
CHANGED
@@ -538,24 +538,17 @@ class StreamExecutor:
             self.stream_var_event[name].set()
 
     def _execute_select(self, expr: SglSelect):
-        (
-            decision,
-            normalized_prompt_logprobs,
-            input_token_logprobs,
-            output_token_logprobs,
-        ) = self.backend.select(self, expr.choices, expr.temperature)
+        choices_decision = self.backend.select(
+            self, expr.choices, expr.temperature, expr.choices_method
+        )
         if expr.name is not None:
             name = expr.name
-            self.variables[name] = decision
-            self.meta_info[name] = {
-                "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "input_token_logprobs": input_token_logprobs,
-                "output_token_logprobs": output_token_logprobs,
-            }
+            self.variables[name] = choices_decision.decision
+            self.meta_info[name] = choices_decision.meta_info
             self.variable_event[name].set()
             if self.stream_var_event:
                 self.stream_var_event[name].set()
-        self.text_ += decision
+        self.text_ += choices_decision.decision
 
     def _execute_variable(self, expr: SglVariable):
         src_executor = expr.source_stream_executor
sglang/lang/ir.py
CHANGED
@@ -6,6 +6,7 @@ import warnings
 from typing import List, Optional, Union
 
 from sglang.global_config import global_config
+from sglang.lang.choices import ChoicesSamplingMethod
 
 REGEX_INT = r"[-+]?[0-9]+"
 REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
@@ -15,7 +16,7 @@ REGEX_STRING = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 
 @dataclasses.dataclass
 class SglSamplingParams:
-    max_new_tokens: int =
+    max_new_tokens: int = 128
     stop: Union[str, List[str]] = ()
     temperature: float = 1.0
     top_p: float = 1.0
@@ -139,7 +140,7 @@ class SglFunction:
     def run(
         self,
         *args,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -178,7 +179,7 @@ class SglFunction:
         self,
         batch_kwargs,
         *,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -461,14 +462,22 @@ class SglRoleEnd(SglExpr):
 
 
 class SglSelect(SglExpr):
-    def __init__(self, name: str, choices: List[str], temperature: float):
+
+    def __init__(
+        self,
+        name: str,
+        choices: List[str],
+        temperature: float,
+        choices_method: ChoicesSamplingMethod,
+    ):
         super().__init__()
         self.name = name
         self.choices = choices
         self.temperature = temperature
+        self.choices_method = choices_method
 
     def __repr__(self):
-        return f"Select({self.name}, choices={self.choices})"
+        return f"Select({self.name}, choices={self.choices}, choices_method={self.choices_method})"
 
 
 class SglFork(SglExpr):
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
 
 class FSMCache(BaseToolCache):
-    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)
 
-        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
            # Do not support TiktokenTokenizer or SentencePieceTokenizer
            return
 
sglang/srt/layers/activation.py
ADDED
@@ -0,0 +1,33 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for activation layers."""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flashinfer.activation import silu_and_mul
+from vllm.model_executor.custom_op import CustomOp
+
+
+class SiluAndMul(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
sglang/srt/layers/{token_attention.py → decode_attention.py}
RENAMED
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for decoding.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)
 
 
-def _token_att_m_fwd(
+def _decode_att_m_fwd(
     q,
     k_buffer,
     att_out,
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
     )
 
 
-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
     logics,
     v_buffer,
     o,
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
     )
 
 
-def token_attention_fwd(
+def decode_attention_fwd(
     q,
     k_buffer,
     v_buffer,
@@ -312,7 +316,7 @@ def token_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
 
-    _token_att_m_fwd(
+    _decode_att_m_fwd(
         q,
         k_buffer,
         att_m,
@@ -324,7 +328,7 @@ def token_attention_fwd(
         sm_scale,
         logit_cap,
     )
-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
         att_m,
         v_buffer,
         o,
sglang/srt/layers/extend_attention.py
CHANGED
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for prefill.
+It supports page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
sglang/srt/layers/layernorm.py
ADDED
@@ -0,0 +1,65 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for normalization layers."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm.model_executor.custom_op import CustomOp
+
+
+class RMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
sglang/srt/layers/logits_processor.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 
-from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
         all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            all_logits /= self.config.final_logit_softcapping
+            all_logits = torch.tanh(all_logits)
+            all_logits *= self.config.final_logit_softcapping
+
         all_logprobs = all_logits
         del all_logits, hidden_states
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
sglang/srt/layers/pooler.py
ADDED
@@ -0,0 +1,50 @@
+# adapted from
+# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
+
+from dataclasses import dataclass
+from enum import IntEnum
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+class PoolingType(IntEnum):
+    LAST = 0
+
+
+@dataclass
+class EmbeddingPoolerOutput:
+    embeddings: torch.Tensor
+
+
+class Pooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `PoolerOutput`.
+    Attributes:
+        pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+        normalize: Whether to normalize the pooled data.
+    """
+
+    def __init__(self, pooling_type: PoolingType, normalize: bool):
+        super().__init__()
+        self.pooling_type = pooling_type
+        self.normalize = normalize
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+    ) -> EmbeddingPoolerOutput:
+        if self.pooling_type == PoolingType.LAST:
+            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            pooled_data = hidden_states[last_token_indices]
+        else:
+            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
+
+        if self.normalize:
+            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+
+        return EmbeddingPoolerOutput(embeddings=pooled_data)