sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py
CHANGED
@@ -8,19 +8,21 @@ from typing import List, Optional, Union
 from sglang.global_config import global_config
 from sglang.lang.choices import ChoicesSamplingMethod
 
-REGEX_INT = r"[-+]?[0-9]+"
-REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
+REGEX_INT = r"[-+]?[0-9]+[ \n]*"
+REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
 REGEX_BOOL = r"(True|False)"
-
+REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 
 
 @dataclasses.dataclass
 class SglSamplingParams:
     max_new_tokens: int = 128
     stop: Union[str, List[str]] = ()
+    stop_token_ids: Optional[List[int]] = ()
     temperature: float = 1.0
     top_p: float = 1.0
     top_k: int = -1  # -1 means disable
+    min_p: float = 0.0
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
     ignore_eos: bool = False
@@ -37,9 +39,11 @@ class SglSamplingParams:
         return SglSamplingParams(
             self.max_new_tokens,
             self.stop,
+            self.stop_token_ids,
             self.temperature,
             self.top_p,
             self.top_k,
+            self.min_p,
             self.frequency_penalty,
             self.presence_penalty,
             self.ignore_eos,
@@ -108,9 +112,11 @@ class SglSamplingParams:
         return {
             "max_new_tokens": self.max_new_tokens,
             "stop": self.stop,
+            "stop_token_ids": self.stop_token_ids,
             "temperature": self.temperature,
             "top_p": self.top_p,
             "top_k": self.top_k,
+            "min_p": self.min_p,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
             "ignore_eos": self.ignore_eos,
@@ -141,10 +147,12 @@ class SglFunction:
         self,
         *args,
         max_new_tokens: int = 128,
-        stop: Union[str, List[str]] = (),
+        stop: Union[str, List[str]] = [],
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
@@ -161,9 +169,11 @@ class SglFunction:
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
+            stop_token_ids=stop_token_ids,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
@@ -181,9 +191,11 @@ class SglFunction:
         *,
         max_new_tokens: int = 128,
         stop: Union[str, List[str]] = (),
+        stop_token_ids: Optional[List[int]] = [],
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
@@ -218,9 +230,11 @@ class SglFunction:
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
+            stop_token_ids=stop_token_ids,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
@@ -397,9 +411,11 @@ class SglGen(SglExpr):
         name: Optional[str] = None,
         max_new_tokens: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
+        min_p: Optional[float] = None,
         frequency_penalty: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         ignore_eos: Optional[bool] = None,
@@ -416,9 +432,11 @@ class SglGen(SglExpr):
         self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
+            stop_token_ids=stop_token_ids,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
+            min_p=min_p,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
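The practical effect of these ir.py changes is two new sampling knobs, min_p and stop_token_ids, threaded through SglSamplingParams, SglFunction.run/run_batch, and sgl.gen. A minimal sketch of how they surface in the frontend API, assuming a server already running at http://localhost:30000; the token id 128009 (llama-3's <|eot_id|>) and the prompt text are illustrative only:

import sglang as sgl

@sgl.function
def answer(s, question):
    s += "Q: " + question + "\nA:"
    s += sgl.gen(
        "ans",
        max_new_tokens=64,
        min_p=0.05,               # new in 0.2.14: min-p filtering
        stop_token_ids=[128009],  # new in 0.2.14: stop on raw token ids
    )

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = answer.run(question="Name a prime number.")
print(state["ans"])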
sglang/launch_server.py
CHANGED
@@ -1,9 +1,11 @@
 """Launch the inference server."""
 
 import argparse
+import os
 
 from sglang.srt.server import launch_server
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_child_process
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -11,4 +13,9 @@ if __name__ == "__main__":
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
 
-    launch_server(server_args)
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_child_process(os.getpid(), including_parent=False)
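The launch_server.py change wraps the blocking launch_server call in try/finally so that worker subprocesses (tensor-parallel workers, the detokenizer, etc.) are reaped even when the entry point exits abnormally. A sketch of the same cleanup pattern using psutil; this illustrates the idea and is not sglang's exact kill_child_process implementation:

import os
import psutil

def kill_children(pid: int) -> None:
    """Terminate every descendant of `pid`, ignoring already-dead ones."""
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

# Usage mirrors the new entry point:
#   try: launch_server(server_args)
#   finally: kill_children(os.getpid())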
sglang/srt/constrained/jump_forward.py
CHANGED
@@ -62,16 +62,22 @@ class JumpForwardMap:
             id_to_symbol.setdefault(id_, []).append(symbol)
 
         transitions = fsm_info.transitions
+
         outgoings_ct = defaultdict(int)
-
+        # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
+        for s in fsm_info.finals:
+            outgoings_ct[s] = 1
 
+        state_to_jump_forward = {}
         for (state, id_), next_state in transitions.items():
             if id_ == fsm_info.alphabet_anything_value:
+                # Arbitrarily symbol cannot be recognized as jump forward
                 continue
+
             symbols = id_to_symbol[id_]
             for c in symbols:
                 if len(c) > 1:
-                    # Skip byte level transitions
+                    # Skip byte level transitions like c = "5E"
                     continue
 
                 outgoings_ct[state] += 1
@@ -87,6 +93,9 @@ class JumpForwardMap:
 
         # Process the byte level jump forward
         outgoings_ct = defaultdict(int)
+        for s in fsm_info.finals:
+            outgoings_ct[s] = 1
+
         for (state, id_), next_state in transitions.items():
             if id_ == fsm_info.alphabet_anything_value:
                 continue
@@ -177,3 +186,5 @@ if __name__ == "__main__":
     test_main(r"霍格沃茨特快列车|霍比特人比尔博")
     # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
     # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
+
+    test_main(r"[-+]?[0-9]+[ ]*")
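The substantive fix here is seeding outgoings_ct with 1 for every final state: a final state can always "exit" by terminating, so even one real outgoing edge makes its next character ambiguous and disqualifies it as a jump-forward point. A toy sketch of the idea on a hand-written DFA (not the interegular FSM types the real code uses):

from collections import defaultdict

transitions = {  # (state, char) -> next_state, for the regex "ab+c"
    (0, "a"): 1,
    (1, "b"): 2,
    (2, "b"): 2,
    (2, "c"): 3,
}
finals = {3}

outgoings = defaultdict(int)
for s in finals:
    outgoings[s] = 1  # terminating counts as an edge, as in the patched code
first_edge = {}
for (state, ch), nxt in transitions.items():
    outgoings[state] += 1
    first_edge.setdefault(state, (ch, nxt))

def jump_forward(state):
    """Greedily concatenate forced characters starting from `state`."""
    out = []
    while outgoings[state] == 1 and state in first_edge:
        ch, state = first_edge[state]
        out.append(ch)
    return "".join(out)

print(jump_forward(0))  # "ab" -- state 2 then has two choices, so we stop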
sglang/srt/conversation.py
CHANGED
@@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum):
     NO_COLON_TWO = auto()
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
+    LLAMA3 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -137,6 +138,20 @@ class Conversation:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA3:
+            ret = "<|begin_of_text|>"
+            if self.system_message:
+                ret += system_prompt
+            else:
+                ret += ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+                    ret += f"{message.strip()}<|eot_id|>"
+                else:
+                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
+            # print(ret)
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
             if self.system_message:
@@ -379,12 +394,23 @@ def generate_chat_conv(
             conv.append_message(conv.roles[0], message.content)
         else:
             real_content = ""
+            # calculate number of image_url
+            num_image_url = 0
+            for content in message.content:
+                if content.type == "image_url":
+                    num_image_url += 1
+            if num_image_url > 1:
+                image_token = "<image>"
+            else:
+                image_token = "<image>\n"
             for content in message.content:
                 if content.type == "text":
+                    if num_image_url > 16:
+                        real_content += "\n"  # for video
                     real_content += content.text
                 elif content.type == "image_url":
                     # NOTE: Only works for llava
-                    real_content += "<image>\n"
+                    real_content += image_token
                     conv.append_image(content.image_url.url)
             conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
@@ -425,6 +451,18 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="chatml-llava",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are a helpful assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_str=["<|endoftext|>", "<|im_end|>"],
+    )
+)
+
 register_conv_template(
     Conversation(
         name="vicuna_v1.1",
@@ -437,6 +475,17 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="llava_llama_3",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA3,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot_id|>"],
+    )
+)
 # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442
 register_conv_template(
     Conversation(
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -30,14 +30,19 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
-from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
 
-
+try:
+    from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
+
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
+        ChatGLMConfig.model_type: ChatGLMConfig,
+        DbrxConfig.model_type: DbrxConfig,
+    }
+except ImportError:
+    # We want this file to run without vllm dependency
+    _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
 
-_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
-    ChatGLMConfig.model_type: ChatGLMConfig,
-    DbrxConfig.model_type: DbrxConfig,
-}
+from sglang.srt.utils import is_multimodal_model
 
 
 def download_from_hf(model_path: str):
@@ -137,18 +142,6 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
-    if (
-        "llama" in tokenizer_name.lower()
-        and kwargs.get("use_fast", True)
-        and tokenizer_name != _FAST_LLAMA_TOKENIZER
-    ):
-        pass
-        # warnings.warn(
-        #     "For some LLaMA V1 models, initializing the fast tokenizer may "
-        #     "take a long time. To reduce the initialization time, consider "
-        #     f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original "
-        #     "tokenizer."
-        # )
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -229,6 +222,8 @@ class TiktokenTokenizer:
         }
         assert tok_dict["word_split"] == "V1"
 
+        default_allowed_special = None
+
         kwargs = {
             "name": name,
             "pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -242,14 +237,18 @@ class TiktokenTokenizer:
                     for bytes_list in tok_dict["default_allowed_special"]
                 ]
             )
-        else:
-            default_allowed_special = None
         if "vocab_size" in tok_dict:
             kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
 
+        PAD = "<|pad|>"
+        EOS = "<|eos|>"
+        SEP = "<|separator|>"
+
+        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer.
+        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
 
         def encode_patched(
             self,
@@ -266,14 +265,14 @@ class TiktokenTokenizer:
                 self,
                 text,
                 allowed_special=allowed_special,
-                disallowed_special=
+                disallowed_special=(),
             )
 
         tokenizer.encode = functools.partial(encode_patched, tokenizer)
 
         # Convert to HF interface
         self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[
+        self.eos_token_id = tokenizer._special_tokens[EOS]
         self.vocab_size = tokenizer.n_vocab
         self.chat_template = Template(
             "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
sglang/srt/layers/activation.py
CHANGED
@@ -14,20 +14,42 @@ limitations under the License.
 """Fused operators for activation layers."""
 
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import silu_and_mul
+from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
 from vllm.model_executor.custom_op import CustomOp
 
 
 class SiluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        if self.is_lower_sm80:
+            return self.forward_native(x)
+
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         silu_and_mul(x, out)
         return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        gelu_tanh_and_mul(x, out)
+        return out
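A quick reference check of what GeluAndMul computes, using plain torch on CPU (no flashinfer required): split the last dimension in half, apply tanh-approximate GELU to the first half, and multiply elementwise by the second half. forward_cuda produces the same values through flashinfer's fused gelu_tanh_and_mul kernel, writing into a preallocated output buffer.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)  # last dim is 2 * d
d = x.shape[-1] // 2
out = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
assert out.shape == (2, 4)  # half the input's last dimension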