sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +151 -40
- sglang/bench_serving.py +46 -22
- sglang/check_env.py +24 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -29
- sglang/lang/choices.py +164 -0
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +14 -5
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +6 -1
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +4 -7
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +174 -380
- sglang/srt/managers/tokenizer_manager.py +197 -112
- sglang/srt/managers/tp_worker.py +299 -364
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +10 -15
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +27 -12
- sglang/srt/model_executor/forward_batch_info.py +319 -0
- sglang/srt/model_executor/model_runner.py +30 -47
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +1 -1
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -2
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +3 -8
- sglang/srt/models/llama2.py +5 -5
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -12
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +189 -39
- sglang/srt/openai_api/protocol.py +43 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +93 -21
- sglang/srt/server_args.py +30 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +21 -3
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.10.dist-info/RECORD +0 -100
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
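The version bump itself lives in sglang/version.py, with sglang/__init__.py picking up eight new exports. After upgrading, the installed version can be confirmed roughly as follows; this assumes the package exposes the conventional __version__ attribute, which the file list suggests but this diff does not show directly.

import sglang

# Expected to print 0.2.12 once the upgrade described by this diff is installed.
print(sglang.__version__)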
sglang/srt/model_executor/forward_batch_info.py
ADDED
@@ -0,0 +1,319 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""ModelRunner runs the forward passes of the models."""
+from dataclasses import dataclass
+from enum import IntEnum, auto
+from typing import TYPE_CHECKING, List
+
+import numpy as np
+import torch
+
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class ForwardMode(IntEnum):
+    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
+    PREFILL = auto()
+    # Extend a sequence. The KV cache of the first part of the sequence is already computed (e.g., system prompt).
+    EXTEND = auto()
+    # Decode one token.
+    DECODE = auto()
+
+
+@dataclass
+class InputMetadata:
+    """Store all inforamtion of a forward pass."""
+
+    forward_mode: ForwardMode
+    batch_size: int
+    req_pool_indices: torch.Tensor
+    seq_lens: torch.Tensor
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: BaseTokenToKVPool
+
+    # Output location of the KV cache
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None
+
+    # Output options
+    return_logprob: bool = False
+    top_logprobs_nums: List[int] = None
+
+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
+    # Trition attention backend
+    triton_max_seq_len: int = 0
+    triton_max_extend_len: int = 0
+    triton_start_loc: torch.Tensor = None
+    triton_prefix_lens: torch.Tensor = None
+
+    # FlashInfer attention backend
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    flashinfer_use_ragged: bool = False
+
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
+            )
+            for r in reqs
+        ]
+
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
+
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
+            else:
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
+        else:
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.fill_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.fill_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            extend_lens_cpu = [
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
+        ret = cls(
+            forward_mode=forward_mode,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
+        )
+
+        ret.compute_positions(batch)
+
+        ret.compute_extend_infos(batch)
+
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
+        if model_runner.server_args.disable_flashinfer:
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
+
+        return ret
+
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self, model_runner, prefix_lens, flashinfer_use_ragged
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
+
+def update_flashinfer_indices(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper=None,
+    flashinfer_use_ragged=False,
+):
+    """Init auxiliary variables for FlashInfer attention backend."""
+    num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    head_dim = model_runner.model_config.head_dim
+    batch_size = len(req_pool_indices)
+
+    if flashinfer_use_ragged:
+        paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
+
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+    req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+    paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+    kv_indices = torch.cat(
+        [
+            model_runner.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+            ]
+            for i in range(batch_size)
+        ],
+        dim=0,
+    ).contiguous()
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+
+    if forward_mode == ForwardMode.DECODE:
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if flashinfer_decode_wrapper is None:
+            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
+        flashinfer_decode_wrapper.end_forward()
+        flashinfer_decode_wrapper.begin_forward(
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+    else:
+        # extend part
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+
+        if flashinfer_use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
+
+        # cached part
+        model_runner.flashinfer_prefill_wrapper_paged.end_forward()
+        model_runner.flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
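A pattern worth calling out in the file above: per-request lengths are repeatedly converted into start offsets by writing a cumulative sum into positions 1..n of a zero-initialized tensor (extend_start_loc, triton_start_loc), or into a batch_size + 1 "indptr" array (kv_indptr, qo_indptr) for FlashInfer. A minimal, self-contained sketch of that layout, using hypothetical lengths on CPU:

import torch

# Hypothetical per-request token counts for a batch of 3 requests.
extend_seq_lens = torch.tensor([5, 2, 7])

# Exclusive prefix sum: start offset of each request in the flattened token dimension.
extend_start_loc = torch.zeros_like(extend_seq_lens)
extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
# extend_start_loc -> tensor([0, 5, 7])

# FlashInfer-style indptr has batch_size + 1 entries, so request i spans
# kv_indptr[i] : kv_indptr[i + 1] in the flattened KV index array.
kv_indptr = torch.zeros(extend_seq_lens.numel() + 1, dtype=torch.int64)
kv_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)
# kv_indptr -> tensor([0, 5, 7, 14])

print(extend_start_loc.tolist(), kv_indptr.tolist())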
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -41,21 +41,18 @@ from vllm.distributed import (
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.schedule_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
 from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_generation_model,
     is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
@@ -134,10 +131,12 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.
+        self.init_flashinfer()

-
-
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+            # Capture cuda graphs
+            self.init_cuda_graphs()

     def load_model(self):
         logger.info(
@@ -188,6 +187,10 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
+
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
@@ -291,7 +294,7 @@ class ModelRunner:
         c = a @ b
         return c

-    def
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
@@ -350,65 +353,42 @@ class ModelRunner:
         )

     @torch.inference_mode()
-    def forward_decode(self, batch:
+    def forward_decode(self, batch: ScheduleBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

-        input_metadata = InputMetadata.
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, ForwardMode.DECODE
         )
+
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )

     @torch.inference_mode()
-    def forward_extend(self, batch:
-        input_metadata = InputMetadata.
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+    def forward_extend(self, batch: ScheduleBatch):
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )

     @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch:
-        input_metadata = InputMetadata.
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
+    def forward_extend_multi_modal(self, batch: ScheduleBatch):
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-
-
-
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )

-    def forward(self, batch:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
@@ -433,8 +413,10 @@ def import_model_classes():
                     entry, list
                ):  # To support multiple model classes in one module
                     for tmp in entry:
+                        assert tmp.__name__ not in model_arch_name_to_cls
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
+                    assert entry.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[entry.__name__] = entry

         # compat: some models such as chatglm has incorrect class set in config.json
@@ -444,6 +426,7 @@ def import_model_classes():
             ):
                 for remap in module.EntryClassRemapping:
                     if isinstance(remap, tuple) and len(remap) == 2:
+                        assert remap[0] not in model_arch_name_to_cls
                         model_arch_name_to_cls[remap[0]] = remap[1]

     return model_arch_name_to_cls
sglang/srt/models/chatglm.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None

sglang/srt/models/commandr.py
CHANGED
@@ -64,7 +64,7 @@ from vllm.model_executor.utils import set_weight_attrs

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 @torch.compile
sglang/srt/models/dbrx.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class DbrxRouter(nn.Module):
sglang/srt/models/deepseek.py
CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class DeepseekMLP(nn.Module):
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -46,7 +46,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class DeepseekV2MLP(nn.Module):
sglang/srt/models/gemma.py
CHANGED
@@ -37,7 +37,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class GemmaMLP(nn.Module):
sglang/srt/models/gemma2.py
CHANGED
@@ -38,11 +38,10 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class GemmaRMSNorm(CustomOp):
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class GPTBigCodeAttention(nn.Module):
sglang/srt/models/grok.py
CHANGED
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata

 use_fused = True

sglang/srt/models/internlm2.py
CHANGED
@@ -23,8 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class InternLM2MLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):


 class InternLM2Attention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):


 class InternLMDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):


 class InternLM2Model(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):


 class InternLM2ForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
sglang/srt/models/llama2.py
CHANGED
@@ -24,8 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,9 +37,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class LlamaMLP(nn.Module):
@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) ->
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
sglang/srt/models/llama_classification.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel

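All of the per-model edits above apply the same import migration: SiluAndMul and RMSNorm now come from sglang's own sglang.srt.layers package instead of vllm.model_executor.layers, and InputMetadata moves to sglang.srt.model_executor.forward_batch_info. A condensed before/after sketch of a typical model-file header follows; the removed import lines are truncated in this diff view, so the old sglang-side paths are not reproduced here.

# Before (0.2.10):
# from vllm.model_executor.layers.activation import SiluAndMul
# from vllm.model_executor.layers.layernorm import RMSNorm

# After (0.2.12):
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata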