sglang-0.2.11-py3-none-any.whl → sglang-0.2.12-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_latency.py +6 -4
- sglang/bench_serving.py +46 -22
- sglang/lang/compiler.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/layers/activation.py +33 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +6 -1
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +5 -0
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +110 -87
- sglang/srt/managers/tokenizer_manager.py +193 -111
- sglang/srt/managers/tp_worker.py +289 -352
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +168 -105
- sglang/srt/model_executor/model_runner.py +24 -37
- sglang/srt/models/gemma2.py +0 -1
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/qwen2_moe.py +0 -11
- sglang/srt/openai_api/adapter.py +155 -27
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -4
- sglang/srt/server.py +69 -15
- sglang/srt/server_args.py +26 -19
- sglang/srt/utils.py +31 -13
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +63 -63
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +4 -2
- sglang/test/test_utils.py +20 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
- sglang-0.2.12.dist-info/RECORD +112 -0
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
 
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -39,25 +43,33 @@ class InputMetadata:
 
     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
    req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
 
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None
 
     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
     # Trition attention backend
     triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
@@ -70,107 +82,171 @@ class InputMetadata:
 
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
 
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
             )
+            for r in reqs
+        ]
 
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
 
-        if forward_mode == ForwardMode.DECODE:
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
             else:
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
         else:
-                )
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.fill_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.fill_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            extend_lens_cpu = [
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
 
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
         ret = cls(
             forward_mode=forward_mode,
-            batch_size=batch_size,
-            seq_lens=seq_lens,
-            positions=positions,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.compute_positions(batch)
+
+        ret.compute_extend_infos(batch)
+
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
         if model_runner.server_args.disable_flashinfer:
-            (
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
 
         return ret
 
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self, model_runner, prefix_lens, flashinfer_use_ragged
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
 
-def init_flashinfer_args(
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
@@ -178,7 +254,6 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
 
     if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
@@ -201,6 +276,10 @@ def init_flashinfer_args(
     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
     if forward_mode == ForwardMode.DECODE:
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if flashinfer_decode_wrapper is None:
+            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
         flashinfer_decode_wrapper.end_forward()
         flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
@@ -238,19 +317,3 @@ def init_flashinfer_args(
             head_dim,
             1,
         )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
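
The net effect of this refactor is that callers no longer build InputMetadata field by field; they pass the whole ScheduleBatch plus a ForwardMode, and the new classmethod derives positions, extend info, and backend-specific state itself. A minimal sketch of the new call pattern (mirroring the model_runner.py hunks below; not part of the diff, names are illustrative):

    input_metadata = InputMetadata.from_schedule_batch(
        model_runner,                     # supplies req_to_token_pool, token_to_kv_pool, server_args
        batch,                            # a ScheduleBatch carrying reqs, seq_lens, out_cache_loc, ...
        forward_mode=ForwardMode.EXTEND,  # or ForwardMode.DECODE
    )
    logits = model.forward(batch.input_ids, input_metadata.positions, input_metadata)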
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_generation_model,
     is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
@@ -130,10 +131,12 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.
+        self.init_flashinfer()
 
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+            # Capture cuda graphs
+            self.init_cuda_graphs()
 
     def load_model(self):
         logger.info(
@@ -184,6 +187,10 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
+
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
@@ -287,7 +294,7 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
@@ -350,33 +357,18 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, ForwardMode.DECODE
         )
+
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +376,16 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )
 
     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
@@ -429,8 +413,10 @@ def import_model_classes():
                     entry, list
                 ):  # To support multiple model classes in one module
                     for tmp in entry:
+                        assert tmp.__name__ not in model_arch_name_to_cls
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
+                    assert entry.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[entry.__name__] = entry
 
             # compat: some models such as chatglm has incorrect class set in config.json
@@ -440,6 +426,7 @@ def import_model_classes():
             ):
                 for remap in module.EntryClassRemapping:
                     if isinstance(remap, tuple) and len(remap) == 2:
+                        assert remap[0] not in model_arch_name_to_cls
                         model_arch_name_to_cls[remap[0]] = remap[1]
 
     return model_arch_name_to_cls
sglang/srt/models/gemma2.py
CHANGED
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
sglang/srt/models/internlm2.py
CHANGED
@@ -23,8 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class InternLM2MLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):
 
 
 class InternLM2Attention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):
 
 
 class InternLMDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):
 
 
 class InternLM2Model(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):
 
 
 class InternLM2ForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
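
Both internlm2.py and llama2.py (next) swap vllm's SiluAndMul and RMSNorm for the new in-tree sglang.srt.layers.activation and sglang.srt.layers.layernorm modules added in this release (+33 and +65 lines in the summary above, not shown here). For reference, a drop-in SiluAndMul is conventionally the gated SiLU below; this is an assumption about the contract the swap preserves, not an excerpt from the new module:

    import torch
    import torch.nn.functional as F

    def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
        # x packs [gate, up] along the last dimension; output has half that width
        gate, up = x.chunk(2, dim=-1)
        return F.silu(gate) * up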
sglang/srt/models/llama2.py
CHANGED
@@ -24,8 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,7 +37,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) ->
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
sglang/srt/models/llama_embedding.py
ADDED
@@ -0,0 +1,88 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+class LlamaEmbeddingModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config=None,
+        cache_config=None,
+        efficient_weight_load=False,
+    ) -> None:
+        super().__init__()
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.pooler(hidden_states, input_metadata)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+
+        def load_weights_per_param(name, loaded_weight):
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                return
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                return
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    return
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    return
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
+
+EntryClass = LlamaEmbeddingModel
+# compat: e5-mistral model.config class == MistralModel
+EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]