sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/hf_transformers_utils.py +0 -149
- sglang/srt/layers/activation.py +93 -11
- sglang/srt/layers/layernorm.py +47 -4
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +15 -68
- sglang/srt/managers/io_struct.py +5 -4
- sglang/srt/managers/schedule_batch.py +20 -25
- sglang/srt/managers/tokenizer_manager.py +74 -61
- sglang/srt/managers/tp_worker.py +49 -43
- sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- sglang/srt/model_executor/forward_batch_info.py +9 -26
- sglang/srt/model_executor/model_runner.py +20 -17
- sglang/srt/models/chatglm.py +13 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/gemma.py +3 -7
- sglang/srt/models/gemma2.py +2 -56
- sglang/srt/models/gpt_bigcode.py +2 -6
- sglang/srt/models/grok.py +10 -8
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama2.py +6 -11
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/mixtral.py +1 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +2 -5
- sglang/srt/models/qwen2.py +5 -10
- sglang/srt/models/qwen2_moe.py +21 -24
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +85 -4
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -74
- sglang/srt/sampling/sampling_params.py +4 -0
- sglang/srt/server.py +11 -4
- sglang/srt/utils.py +18 -33
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
sglang/launch_server_llavavid.py
ADDED
@@ -0,0 +1,26 @@
+"""Launch the inference server for Llava-video model."""
+
+import argparse
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+
+    model_overide_args = {}
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    launch_server(server_args, model_overide_args, None)
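The new script is meant to be launched directly, e.g. `python -m sglang.launch_server_llavavid --model-path <checkpoint>`. A rough programmatic sketch of the same call follows; the checkpoint name is a placeholder, and the positional `launch_server(server_args, model_overide_args, pipe)` form simply mirrors the call above.

from sglang.srt.server import ServerArgs, launch_server

# Hypothetical checkpoint; substitute a real Llava-video model path.
server_args = ServerArgs(model_path="lmms-lab/LLaVA-NeXT-Video-7B")
model_overide_args = {
    "architectures": ["LlavaVidForCausalLM"],
    "model_type": "llavavid",
    "mm_spatial_pool_stride": 2,
    "num_frames": 16,
}
launch_server(server_args, model_overide_args, None)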
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Cache for the compressed finite state machine."""
 
+from outlines.fsm.json_schema import build_regex_from_schema
+
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
+        self.json_schema_mode = json_schema_mode
+
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
             tokenizer_path, **tokenizer_args_dict
         )
 
-    def init_value(self,
-
+    def init_value(self, value):
+        if self.json_schema_mode:
+            regex = build_regex_from_schema(value)
+            return RegexGuide(regex, self.outlines_tokenizer), regex
+        else:
+            return RegexGuide(value, self.outlines_tokenizer)
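A minimal sketch of the schema-to-regex step that the new json_schema_mode performs before compiling the guide; building the full RegexGuide additionally needs the outlines-wrapped tokenizer held by FSMCache, so only the conversion is shown here.

import json

from outlines.fsm.json_schema import build_regex_from_schema

schema = json.dumps(
    {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
)
# In json_schema_mode, init_value(value) runs this conversion and returns
# (RegexGuide(regex, tokenizer), regex) so callers can reuse the regex.
regex = build_regex_from_schema(schema)
print(regex)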
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -119,24 +119,7 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    if tokenizer_name.endswith(".json"):
-        return TiktokenTokenizer(tokenizer_name)
-
-    if tokenizer_name.endswith(".model"):
-        return SentencePieceTokenizer(tokenizer_name)
-
     """Gets a tokenizer for the given model name via Huggingface."""
-    if is_multimodal_model(tokenizer_name):
-        processor = get_processor(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            tokenizer_revision=tokenizer_revision,
-            **kwargs,
-        )
-        tokenizer = processor.tokenizer
-        return tokenizer
-
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
@@ -199,135 +182,3 @@ def get_processor(
         **kwargs,
     )
     return processor
-
-
-class TiktokenTokenizer:
-    def __init__(self, tokenizer_path):
-        import tiktoken
-        from jinja2 import Template
-
-        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
-
-        # Read JSON
-        name = "tmp-json"
-        with open(tokenizer_path, "rb") as fin:
-            tok_dict = json.load(fin)
-
-        mergeable_ranks = {
-            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
-        }
-        special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"]
-            for item in tok_dict["special_tokens"]
-        }
-        assert tok_dict["word_split"] == "V1"
-
-        default_allowed_special = None
-
-        kwargs = {
-            "name": name,
-            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
-            "mergeable_ranks": mergeable_ranks,
-            "special_tokens": special_tokens,
-        }
-        if "default_allowed_special" in tok_dict:
-            default_allowed_special = set(
-                [
-                    bytes(bytes_list).decode()
-                    for bytes_list in tok_dict["default_allowed_special"]
-                ]
-            )
-        if "vocab_size" in tok_dict:
-            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
-
-        PAD = "<|pad|>"
-        EOS = "<|eos|>"
-        SEP = "<|separator|>"
-
-        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
-
-        tokenizer = tiktoken.Encoding(**kwargs)
-        tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
-
-        def encode_patched(
-            self,
-            text: str,
-            *,
-            allowed_special: Union[
-                Literal["all"], AbstractSet[str]
-            ] = set(),  # noqa: B006
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        ) -> List[int]:
-            if isinstance(allowed_special, set):
-                allowed_special |= self._default_allowed_special
-            return tiktoken.Encoding.encode(
-                self,
-                text,
-                allowed_special=allowed_special,
-                disallowed_special=(),
-            )
-
-        tokenizer.encode = functools.partial(encode_patched, tokenizer)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[EOS]
-        self.vocab_size = tokenizer.n_vocab
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode_batch(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
-
-
-class SentencePieceTokenizer:
-    def __init__(self, tokenizer_path):
-        import sentencepiece as spm
-        from jinja2 import Template
-
-        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_id()
-        self.vocab_size = tokenizer.vocab_size()
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
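With the .json (tiktoken) and .model (SentencePiece) shortcuts removed, get_tokenizer now always resolves through the Hugging Face tokenizer classes. A hedged usage sketch, with a placeholder model name:

from sglang.srt.hf_transformers_utils import get_tokenizer

# Placeholder repo id; any Hugging Face tokenizer path works here.
tokenizer = get_tokenizer("meta-llama/Llama-2-7b-chat-hf")
print(type(tokenizer).__name__, tokenizer.vocab_size)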
sglang/srt/layers/activation.py
CHANGED
@@ -13,25 +13,28 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+from typing import Optional
+
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+from vllm.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs
 
 
 class SiluAndMul(CustomOp):
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
-
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        if self.is_lower_sm80:
-            return self.forward_native(x)
-
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -40,16 +43,95 @@ class SiluAndMul(CustomOp):
 
 
 class GeluAndMul(CustomOp):
-    def __init__(self,
+    def __init__(self, approximate="tanh"):
         super().__init__()
+        self.approximate = approximate
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only support tanh or none")
         return out
+
+
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.input_is_parallel = input_is_parallel
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size, tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.scales = nn.Parameter(
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
+        )
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.act(x) / self.scales
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+}
+
+
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
+) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
+        if intermediate_size is None:
+            raise ValueError(
+                "intermediate_size must be specified for scaled "
+                "activation functions."
+            )
+        return ScaledActivation(
+            act_fn, intermediate_size, input_is_parallel, params_dtype
        )
+    return act_fn
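A small sketch exercising the new approximate switch on GeluAndMul through the pure-PyTorch forward_native path (importing the module still assumes flashinfer and vllm are installed; the CUDA path dispatches to gelu_tanh_and_mul or gelu_and_mul instead):

import torch

from sglang.srt.layers.activation import GeluAndMul

x = torch.randn(4, 2 * 128)  # last dim holds the [gate, up] halves
tanh_out = GeluAndMul(approximate="tanh").forward_native(x)
exact_out = GeluAndMul(approximate="none").forward_native(x)
print(tanh_out.shape, exact_out.shape)  # both torch.Size([4, 128])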
sglang/srt/layers/layernorm.py
CHANGED
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import
+from flashinfer.norm import (
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
+    rmsnorm,
+)
 from vllm.model_executor.custom_op import CustomOp
 
 
@@ -32,15 +37,12 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self.is_lower_sm80:
-            return self.forward_native(x, residual)
 
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
@@ -66,3 +68,44 @@ class RMSNorm(CustomOp):
             return x
         else:
             return x, residual
+
+
+class GemmaRMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
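An illustrative sketch of the new GemmaRMSNorm via forward_native, which applies the Gemma-style (1 + weight) scaling in float32; forward_cuda instead calls flashinfer's gemma_rmsnorm / gemma_fused_add_rmsnorm kernels. Importing the module assumes flashinfer is available.

import torch

from sglang.srt.layers.layernorm import GemmaRMSNorm

norm = GemmaRMSNorm(hidden_size=64)
x = torch.randn(2, 8, 64)
y = norm.forward_native(x)                    # plain RMSNorm
y2, res = norm.forward_native(x, residual=x)  # fused residual-add variant
print(y.shape, y2.shape, res.shape)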
sglang/srt/layers/logits_processor.py
CHANGED
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
-class LogitsProcessorOutput:
+class LogitProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
                 else:
                     output_top_logprobs = None
 
-                return LogitsProcessorOutput(
+                return LogitProcessorOutput(
                     next_token_logits=last_logits,
                     next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
             # Remove the last token logprob for the prefill tokens.
            input_token_logprobs = input_token_logprobs[:-1]
 
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
sglang/srt/layers/sampler.py
CHANGED
@@ -1,6 +1,4 @@
-import dataclasses
 import logging
-from typing import Union
 
 import torch
 from flashinfer.sampling import (
@@ -11,8 +9,6 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -20,71 +16,30 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
 
-    def
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)
 
         if sampling_info.vocab_mask is not None:
             logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
 
-        logits =
+        logits = sampling_info.penalizer_orchestrator.apply(logits)
 
-
-
-    def forward_cuda(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
+        probs = torch.softmax(logits, dim=-1)
 
         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.
+            if sampling_info.min_ps.any():
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -100,23 +55,18 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
 
-
-
-
-
-
-
-
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        if not torch.all(success):
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
 
-        batch_next_token_ids
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
+        return batch_next_token_ids
 
-
+    def forward_native():
+        raise NotImplementedError("Native forward is not implemented yet.")
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -137,10 +87,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
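A self-contained sketch of the fallback added to Sampler.forward_cuda above: when sampling reports failed rows, NaNs are zeroed and the failed rows fall back to greedy (argmax) selection. The tensors here are hypothetical stand-ins for the real probs/success outputs.

import torch

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [float("nan"), float("nan"), float("nan")]])
batch_next_token_ids = torch.tensor([1, 2])
success = torch.tensor([True, False])

if not torch.all(success):
    probs = probs.masked_fill(torch.isnan(probs), 0.0)
    argmax_ids = torch.argmax(probs, dim=-1)
    batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

print(batch_next_token_ids)  # tensor([1, 0]) -- the failed row is replaced by its argmax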
sglang/srt/managers/io_struct.py
CHANGED
@@ -55,6 +55,7 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
+
         if (
             isinstance(self.sampling_params, dict)
             and self.sampling_params.get("n", 1) != 1
@@ -161,10 +162,10 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     # The pixel values for input images
     pixel_values: List[float]
-    # The hash of input images
-
-    # The image
-
+    # The hash values of input images
+    image_hashes: List[int]
+    # The image sizes
+    image_sizes: List[List[int]]
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs