sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +6 -4
- sglang/bench_serving.py +46 -22
- sglang/lang/compiler.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +110 -87
- sglang/srt/managers/tokenizer_manager.py +193 -111
- sglang/srt/managers/tp_worker.py +289 -352
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +168 -105
- sglang/srt/model_executor/model_runner.py +24 -37
- sglang/srt/models/gemma2.py +0 -1
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/qwen2_moe.py +0 -11
- sglang/srt/openai_api/adapter.py +155 -27
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +69 -15
- sglang/srt/server_args.py +26 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +20 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py
CHANGED
@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
 
     return input_ids, reqs
@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
 ):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.
+        req.fill_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
             i, : bench_args.cut_len
         ]
@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
         req.prefix_indices = []
         req.sampling_params = sampling_params
-        req.
+        req.fill_ids = req.origin_input_ids
         reqs.append(req)
 
     return reqs
@@ -238,7 +238,7 @@ def correctness_test(
 
     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len):
+    for _ in range(bench_args.output_len[0]):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])
@@ -332,6 +332,7 @@ def latency_test(
     )
 
     # Warm up
+    rank_print("Warmup ...")
     latency_test_run_once(
         bench_args.run_name,
         model_runner,
@@ -341,6 +342,7 @@ def latency_test(
         bench_args.input_len[0],
         4,  # shorter decoding to speed up the warmup
     )
+    rank_print("Benchmark ...")
 
     # Run the sweep
     result_list = []
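A quick illustration of the invariant behind the `fill_ids` edits above (standalone sketch, not part of the diff; the ids and `cut_len` are made-up values):

```python
# fill_ids is the full token sequence a request feeds through the model.
# In the correctness test, the first cut_len tokens are prefilled and the
# extend step appends the remaining tail.
origin_input_ids = [101, 7592, 2088, 999, 2023, 2003]  # made-up token ids
cut_len = 4

fill_ids = origin_input_ids[:cut_len]    # prefill portion
fill_ids += origin_input_ids[cut_len:]   # extend step appends the tail
assert fill_ids == origin_input_ids
```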
sglang/bench_serving.py
CHANGED
@@ -24,7 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import AsyncGenerator, List, Optional, Tuple, Union
+from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import aiohttp
 import numpy as np
@@ -39,6 +39,8 @@ from transformers import (
 
 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
 
+global args
+
 
 @dataclass
 class RequestFuncInput:
@@ -47,6 +49,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    extra_request_body: Dict[str, Any]
 
 
 @dataclass
@@ -84,6 +87,7 @@ async def async_request_trt_llm(
         "stream": True,
         "min_length": request_func_input.output_len,
         "end_id": 1048576,
+        **request_func_input.extra_request_body,
     }
     if args.disable_ignore_eos:
         del payload["min_length"]
@@ -154,6 +158,7 @@ async def async_request_openai_completions(
         "max_tokens": request_func_input.output_len,
         "stream": not args.disable_stream,
         "ignore_eos": not args.disable_ignore_eos,
+        **request_func_input.extra_request_body,
     }
     headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
 
@@ -192,7 +197,8 @@ async def async_request_openai_completions(
                             output.ttft = ttft
 
                         # Decoding phase
-
+                        else:
+                            output.itl.append(timestamp - most_recent_timestamp)
 
                         most_recent_timestamp = timestamp
                         generated_text += data["choices"][0]["text"]
@@ -542,6 +548,7 @@ async def benchmark(
     request_rate: float,
     disable_tqdm: bool,
     enable_multi: bool,
+    extra_request_body: Dict[str, Any],
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +563,7 @@ async def benchmark(
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
@@ -578,6 +586,7 @@ async def benchmark(
             api_url=api_url,
             prompt_len=prompt_len,
             output_len=output_len,
+            extra_request_body=extra_request_body,
         )
         tasks.append(
             asyncio.create_task(
@@ -660,19 +669,20 @@ async def benchmark(
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": request_rate,
-            "
-            "
-            "
-            "
-            "
-            "
-            "
-            "
+            "total_input_tokens": metrics.total_input,
+            "total_output_tokens": metrics.total_output,
+            "total_output_tokens_retokenized": metrics.total_output_retokenized,
+            "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
+            "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+            "median_ttft_ms": metrics.median_ttft_ms,
+            "median_itl_ms": metrics.median_itl_ms,
+            "output_throughput": metrics.output_throughput,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
             "random_output_len": args.random_output_len,
             "random_range_ratio": args.random_range_ratio,
-            "
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
         }
     else:
         print(f"Error running benchmark for request rate: {request_rate}")
@@ -742,10 +752,18 @@ def check_chat_template(model_path):
         return False
 
 
-def
+def run_benchmark(args_: argparse.Namespace):
+    global args
+    args = args_
+
+    set_ulimit()
     random.seed(args.seed)
    np.random.seed(args.seed)
 
+    extra_request_body = {}
+    if args.extra_request_body:
+        extra_request_body = json.loads(args.extra_request_body)
+
     if args.port is None:
         args.port = {
             "sglang": 30000,
@@ -838,10 +856,11 @@ def fire(args: argparse.Namespace):
                     request_rate=rate,
                     disable_tqdm=args.disable_tqdm,
                     enable_multi=args.multi,
+                    extra_request_body=extra_request_body,
                 )
             )
     else:
-        asyncio.run(
+        return asyncio.run(
             benchmark(
                 backend=backend,
                 api_url=api_url,
@@ -851,6 +870,7 @@ def fire(args: argparse.Namespace):
                 request_rate=args.request_rate,
                 disable_tqdm=args.disable_tqdm,
                 enable_multi=args.multi,
+                extra_request_body=extra_request_body,
             )
         )
 
@@ -949,11 +969,6 @@ if __name__ == "__main__":
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
     )
     parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
-    parser.add_argument(
-        "--disable-tqdm",
-        action="store_true",
-        help="Specify to disable tqdm progress bar.",
-    )
     parser.add_argument(
         "--multi",
         action="store_true",
@@ -966,6 +981,11 @@ if __name__ == "__main__":
         help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
     )
     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
     parser.add_argument(
         "--disable-stream",
         action="store_true",
@@ -976,8 +996,12 @@ if __name__ == "__main__":
         action="store_true",
         help="Disable ignoring EOS.",
     )
-
-
-
+    parser.add_argument(
+        "--extra-request-body",
+        metavar='{"key1": "value1", "key2": "value2"}',
+        type=str,
+        help="Append given JSON object to the request payload. You can use this to specify"
+        "additional generate params like sampling params.",
+    )
     args = parser.parse_args()
-
+    run_benchmark(args)
sglang/lang/compiler.py
CHANGED
@@ -125,7 +125,7 @@ class CompiledFunction:
     def run(
         self,
         *,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -155,7 +155,7 @@ class CompiledFunction:
         self,
         batch_kwargs,
         *,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
sglang/lang/ir.py
CHANGED
@@ -16,7 +16,7 @@ REGEX_STRING = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 
 @dataclasses.dataclass
 class SglSamplingParams:
-    max_new_tokens: int =
+    max_new_tokens: int = 128
     stop: Union[str, List[str]] = ()
     temperature: float = 1.0
     top_p: float = 1.0
@@ -140,7 +140,7 @@ class SglFunction:
     def run(
         self,
         *args,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
@@ -179,7 +179,7 @@ class SglFunction:
         self,
         batch_kwargs,
         *,
-        max_new_tokens: int =
+        max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
         temperature: float = 1.0,
         top_p: float = 1.0,
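Both `compiler.py` and `ir.py` now spell out 128 as the default for `max_new_tokens`. A minimal sketch of where this applies (the function body is illustrative, not from the diff; running it needs a configured backend):

```python
import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    # No max_new_tokens passed here or to run() -> the default of 128 applies.
    s += "A:" + sgl.gen("answer")

# state = qa.run(question="What is 2 + 2?")  # requires sgl.set_default_backend(...)
```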
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
 
 class FSMCache(BaseToolCache):
-    def __init__(
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)
 
-        if
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
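A sketch of the new constructor contract (the tokenizer path and args dict below are placeholders, not values from the diff):

```python
from sglang.srt.constrained.fsm_cache import FSMCache

# With skip_tokenizer_init=True (or a raw .json/.model tokenizer file path),
# __init__ now returns before loading a tokenizer or building any FSMs.
cache = FSMCache(
    "dummy-tokenizer-path",   # placeholder
    {},                       # placeholder tokenizer_args_dict
    enable=True,
    skip_tokenizer_init=True, # new keyword in 0.2.12
)
```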
sglang/srt/layers/activation.py
ADDED
@@ -0,0 +1,33 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for activation layers."""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flashinfer.activation import silu_and_mul
+from vllm.model_executor.custom_op import CustomOp
+
+
+class SiluAndMul(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
sglang/srt/layers/{token_attention.py → decode_attention.py}
RENAMED
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for decoding.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
     tl.store(out_ptrs, acc)
 
 
-def _token_att_m_fwd(
+def _decode_att_m_fwd(
     q,
     k_buffer,
     att_out,
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
     )
 
 
-def _token_softmax_reducev_fwd(
+def _decode_softmax_reducev_fwd(
     logics,
     v_buffer,
     o,
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
     )
 
 
-def token_attention_fwd(
+def decode_attention_fwd(
     q,
     k_buffer,
     v_buffer,
@@ -312,7 +316,7 @@ def token_attention_fwd(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
 
-    _token_att_m_fwd(
+    _decode_att_m_fwd(
         q,
         k_buffer,
         att_m,
@@ -324,7 +328,7 @@ def token_attention_fwd(
         sm_scale,
         logit_cap,
     )
-    _token_softmax_reducev_fwd(
+    _decode_softmax_reducev_fwd(
         att_m,
         v_buffer,
         o,
sglang/srt/layers/extend_attention.py
CHANGED
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for prefill.
+It supporst page size = 1 and prefill with KV cache (i.e. extend).
+"""
+
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.layers.prefill_attention import context_attention_fwd
 
 CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
sglang/srt/layers/layernorm.py
ADDED
@@ -0,0 +1,65 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for normalization layers."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from vllm.model_executor.custom_op import CustomOp
+
+
+class RMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        if residual is not None:
+            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
+            return x, residual
+        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
sglang/srt/layers/logits_processor.py
CHANGED
@@ -208,6 +208,11 @@ class LogitsProcessor(nn.Module):
             all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
 
+        if hasattr(self.config, "final_logit_softcapping"):
+            all_logits /= self.config.final_logit_softcapping
+            all_logits = torch.tanh(all_logits)
+            all_logits *= self.config.final_logit_softcapping
+
         all_logprobs = all_logits
         del all_logits, hidden_states
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
sglang/srt/layers/pooler.py
ADDED
@@ -0,0 +1,50 @@
+# adapted from
+# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
+
+from dataclasses import dataclass
+from enum import IntEnum
+
+import torch
+import torch.nn as nn
+
+from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+class PoolingType(IntEnum):
+    LAST = 0
+
+
+@dataclass
+class EmbeddingPoolerOutput:
+    embeddings: torch.Tensor
+
+
+class Pooler(nn.Module):
+    """A layer that pools specific information from hidden states.
+    This layer does the following:
+    1. Extracts specific tokens or aggregates data based on pooling method.
+    2. Normalizes output if specified.
+    3. Returns structured results as `PoolerOutput`.
+    Attributes:
+        pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+        normalize: Whether to normalize the pooled data.
+    """
+
+    def __init__(self, pooling_type: PoolingType, normalize: bool):
+        super().__init__()
+        self.pooling_type = pooling_type
+        self.normalize = normalize
+
+    def forward(
+        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+    ) -> EmbeddingPoolerOutput:
+        if self.pooling_type == PoolingType.LAST:
+            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+            pooled_data = hidden_states[last_token_indices]
+        else:
+            raise ValueError(f"Invalid pooling type: {self.pooling_type}")
+
+        if self.normalize:
+            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+
+        return EmbeddingPoolerOutput(embeddings=pooled_data)
sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py}
RENAMED
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for prefill.
+It supporst page size = 1.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
 import torch
sglang/srt/layers/radix_attention.py
CHANGED
@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
 
@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
 
-        token_attention_fwd(
+        decode_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -25,10 +25,14 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+)
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import find_printable_text, get_exception_traceback
+from sglang.utils import find_printable_text, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -55,20 +59,40 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-
-
-
-
-
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
 
         self.decode_status = {}
 
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+
+            if isinstance(recv_obj, BatchEmbeddingOut):
+                self.send_to_tokenizer.send_pyobj(
+                    BatchEmbeddingOut(
+                        rids=recv_obj.rids,
+                        embeddings=recv_obj.embeddings,
+                        meta_info=recv_obj.meta_info,
+                        finished_reason=recv_obj.finished_reason,
+                    )
+                )
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
+            if self.tokenizer is None:
+                # Send BatchTokenIDOut if no tokenizer init'ed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -140,8 +164,6 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
-    graceful_registry(inspect.currentframe().f_code.co_name)
-
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception: