sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/hf_transformers_utils.py +0 -149
- sglang/srt/layers/activation.py +93 -11
- sglang/srt/layers/layernorm.py +47 -4
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +15 -68
- sglang/srt/managers/io_struct.py +5 -4
- sglang/srt/managers/schedule_batch.py +20 -25
- sglang/srt/managers/tokenizer_manager.py +74 -61
- sglang/srt/managers/tp_worker.py +49 -43
- sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- sglang/srt/model_executor/forward_batch_info.py +9 -26
- sglang/srt/model_executor/model_runner.py +20 -17
- sglang/srt/models/chatglm.py +13 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/gemma.py +3 -7
- sglang/srt/models/gemma2.py +2 -56
- sglang/srt/models/gpt_bigcode.py +2 -6
- sglang/srt/models/grok.py +10 -8
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama2.py +6 -11
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/mixtral.py +1 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +2 -5
- sglang/srt/models/qwen2.py +5 -10
- sglang/srt/models/qwen2_moe.py +21 -24
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +85 -4
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -74
- sglang/srt/sampling/sampling_params.py +4 -0
- sglang/srt/server.py +11 -4
- sglang/srt/utils.py +18 -33
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
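
Most of the per-model diffs below follow one pattern: the sglang Sampler import and the self.sampler attribute are removed, and forward() returns the LogitsProcessor output directly instead of a (sample_output, logits_output) pair. A hedged before/after sketch of that recurring change; the surrounding names are illustrative and differ slightly per model:

# 0.2.14: the model sampled internally and returned a pair.
def forward(self, input_ids, positions, input_metadata):
    hidden_states = self.model(input_ids, positions, input_metadata)
    logits_output = self.logits_processor(
        input_ids, hidden_states, self.lm_head.weight, input_metadata
    )
    sample_output = self.sampler(logits_output, input_metadata.sampling_info)
    return sample_output, logits_output

# 0.2.14.post2: the model only computes logits; sampling happens outside the model.
def forward(self, input_ids, positions, input_metadata):
    hidden_states = self.model(input_ids, positions, input_metadata)
    return self.logits_processor(
        input_ids, hidden_states, self.lm_head.weight, input_metadata
    )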
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -17,6 +17,7 @@ limitations under the License.

 import bisect
 from contextlib import contextmanager
+from typing import Callable, List

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -25,18 +26,16 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

 from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
-    LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather


@@ -53,12 +52,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):

 @contextmanager
 def patch_model(
-    model: torch.nn.Module,
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None

     try:
-        if
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +66,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -88,7 +87,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
@@ -145,22 +144,18 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]

-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

         if use_torch_compile:
             set_torch_compile_config()

-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs

-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
@@ -181,7 +176,7 @@ class CudaGraphRunner:
             self.output_buffers[bs] = output_buffers
             self.flashinfer_handlers[bs] = flashinfer_handler

-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream

@@ -240,7 +235,6 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -305,35 +299,27 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )

-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-
+        output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
-
-                next_token_logits=
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )

         # Extract logprobs
         if batch.return_logprob:
-
-
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -341,8 +327,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-
-
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
                 )[1]

-        return
+        return output
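
The updated patch_model context manager now receives the compile switch and the tensor-parallel group explicitly. A rough usage sketch of how capture() might drive it, based only on the signatures visible above; the model_runner attributes are assumptions, not taken from the diff:

# Hypothetical driver loop, for illustration only.
for bs in batch_size_list:
    with patch_model(
        self.model_runner.model,                 # assumed attribute
        enable_compile=bs in self.compile_bs,    # torch.compile only for listed sizes
        tp_group=self.model_runner.tp_group,     # assumed attribute
    ) as forward:
        # forward is either model.forward or its torch.compile'd wrapper
        self.capture_one_batch_size(bs, forward)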
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,7 +26,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -45,7 +42,6 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""

     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -62,6 +58,7 @@ class InputMetadata:

     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None

@@ -73,8 +70,8 @@ class InputMetadata:

     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None

     # Trition attention backend
     triton_max_seq_len: int = 0
@@ -91,20 +88,8 @@ class InputMetadata:
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.
-        self.image_offsets = []
-        for r in reqs:
-            if isinstance(r.image_offset, list):
-                self.image_offsets.append(
-                    [
-                        (image_offset - len(r.prefix_indices))
-                        for image_offset in r.image_offset
-                    ]
-                )
-            elif isinstance(r.image_offset, int):
-                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-            elif r.image_offset is None:
-                self.image_offsets.append(0)
+        self.image_sizes = [r.image_sizes for r in reqs]
+        self.image_offsets = [r.image_offsets for r in reqs]

     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -157,6 +142,7 @@ class InputMetadata:
             for i, r in enumerate(batch.reqs)
         ]
         self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -183,7 +169,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -194,8 +179,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

-        ret.sampling_info.prepare_penalties()
-
         ret.compute_positions(batch)

         ret.compute_extend_infos(batch)
@@ -245,10 +228,10 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode
-            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-        else:
+        if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
+        else:
+            prefix_lens = self.extend_prefix_lens

         update_flashinfer_indices(
             self.forward_mode,
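
image_sizes and image_offsets are now typed one level deeper: one entry per request, each listing that request's images. Illustrative values for a two-request batch (not taken from the diff):

image_sizes = [
    [[336, 336], [672, 336]],  # request 0 carries two images
    [[336, 336]],              # request 1 carries one image
]
image_offsets = [
    [5, 40],                   # token offsets of request 0's images
    [7],                       # token offset of request 1's image
]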
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional,
+from typing import Optional, Type

 import torch
 import torch.nn as nn
@@ -44,15 +44,13 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -71,7 +69,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -87,7 +85,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -97,6 +97,13 @@ class ModelRunner:
             }
         )

+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -161,6 +168,8 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")

         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
@@ -507,19 +516,15 @@ class ModelRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable
-                "2.
-                "3.
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )

     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
            return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(
@@ -568,9 +573,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )

-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
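
The new sm75 gate shares the indentation of the float16 fallback above, which only runs when the device is below sm80; a minimal sketch of the resulting control flow (the exact nesting is an assumption inferred from the hunk context):

major, minor = torch.cuda.get_device_capability()
if major < 8:
    dtype = "float16"  # no bfloat16 support below sm80
    if minor < 5:
        # anything older than sm75 (for example sm70) is rejected outright
        raise RuntimeError("SGLang only supports sm75 and above.")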
sglang/srt/models/chatglm.py
CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable,
+from typing import Iterable, Optional, Tuple

 import torch
 from torch import nn
@@ -31,18 +31,20 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None
@@ -381,11 +383,17 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-
-
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/commandr.py
CHANGED
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)

     @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py
CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -45,7 +45,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -633,7 +632,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -642,11 +640,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -34,10 +33,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -61,7 +60,7 @@ class GemmaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")

     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma2.py
CHANGED
@@ -22,11 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
-
-# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,9 +34,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -51,52 +46,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1


-class GemmaRMSNorm(CustomOp):
-    """RMS normalization for Gemma.
-
-    Two differences from the above RMSNorm:
-        1. x * (1 + w) instead of x * w.
-        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.zeros(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        if residual is not None:
-            x = x + residual
-            residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-        return self.forward_native(x, residual)
-
-
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding

@@ -397,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -408,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
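
With the local copy deleted above, Gemma2 now imports GemmaRMSNorm from sglang.srt.layers.layernorm, consistent with the +47/-4 change to layernorm.py in the file list. A minimal usage sketch based on the constructor and return convention of the removed code; the hidden size is illustrative:

from sglang.srt.layers.layernorm import GemmaRMSNorm

# Weight starts at zero and is applied as (1 + w), as in the removed implementation.
norm = GemmaRMSNorm(hidden_size=4096, eps=1e-6)
# x = norm(x)                       # plain normalization
# x, residual = norm(x, residual)   # residual-add form returns (x, residual)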