sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,462 @@
|
|
1
|
+
import importlib
|
2
|
+
import importlib.resources
|
3
|
+
import logging
|
4
|
+
import pkgutil
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from functools import lru_cache
|
7
|
+
from typing import List, Optional, Type
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
import torch
|
11
|
+
import torch.nn as nn
|
12
|
+
from vllm.config import DeviceConfig, LoadConfig
|
13
|
+
from vllm.config import ModelConfig as VllmModelConfig
|
14
|
+
from vllm.distributed import initialize_model_parallel, init_distributed_environment
|
15
|
+
from vllm.model_executor.model_loader import get_model
|
16
|
+
from vllm.model_executor.models import ModelRegistry
|
17
|
+
|
18
|
+
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
|
19
|
+
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
|
20
|
+
from sglang.srt.server_args import ServerArgs
|
21
|
+
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
|
22
|
+
|
23
|
+
|
24
|
+
# Module logger; registered under a fixed name rather than __name__.
logger = logging.getLogger("srt.model_runner")

# for server args in model endpoints
# Filled in by ModelRunner.__init__ with the "enable_flashinfer" and
# "attention_reduce_in_fp32" server options; InputMetadata.create reads
# "enable_flashinfer" from it.
global_server_args_dict = {}
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class InputMetadata:
    """Per-batch bookkeeping handed to the model's forward pass.

    Instances are normally built with :meth:`create`, which derives the
    start offsets, token counts and position ids from the raw batch tensors
    and, depending on the forward mode and server options, fills in the
    extend-specific and flashinfer-specific fields.
    """

    model_runner: "ModelRunner"
    forward_mode: ForwardMode
    batch_size: int
    total_num_tokens: int
    max_seq_len: int
    req_pool_indices: torch.Tensor
    start_loc: torch.Tensor
    seq_lens: torch.Tensor
    prefix_lens: torch.Tensor
    positions: torch.Tensor
    req_to_token_pool: ReqToTokenPool
    token_to_kv_pool: TokenToKVPool

    # for extend (populated by init_extend_args)
    extend_seq_lens: Optional[torch.Tensor] = None
    extend_start_loc: Optional[torch.Tensor] = None
    max_extend_len: int = 0

    out_cache_loc: Optional[torch.Tensor] = None
    out_cache_cont_start: Optional[torch.Tensor] = None
    out_cache_cont_end: Optional[torch.Tensor] = None

    # create() stores a Python scalar here (result of .item()) in DECODE
    # mode and None otherwise.
    other_kv_index: Optional[int] = None
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None

    # for flashinfer (populated by init_flashinfer_args)
    qo_indptr: Optional[torch.Tensor] = None
    kv_indptr: Optional[torch.Tensor] = None
    kv_indices: Optional[torch.Tensor] = None
    kv_last_page_len: Optional[torch.Tensor] = None
    # Deliberately unannotated: these are plain class attributes, not
    # dataclass fields, so they are not part of __init__/__repr__/__eq__.
    prefill_wrapper = None
    decode_wrapper = None

    def init_flashinfer_args(self, tp_size):
        """Build the flashinfer index tensors and start the matching wrapper.

        Args:
            tp_size: Tensor-parallel size; the attention and key/value head
                counts from the model config are divided by it.

        For PREFILL and EXTEND a ``BatchPrefillWithPagedKVCacheWrapper`` is
        created and stored in ``prefill_wrapper``; for every other mode a
        ``BatchDecodeWithPagedKVCacheWrapper`` is stored in
        ``decode_wrapper``.
        """
        from flashinfer import (
            BatchDecodeWithPagedKVCacheWrapper,
            BatchPrefillWithPagedKVCacheWrapper,
        )

        # kv_indptr[i] is the offset of request i inside kv_indices.
        self.kv_indptr = torch.zeros(
            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
        )
        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
        self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
        seq_lens_cpu = self.seq_lens.cpu().numpy()
        # Concatenate, request by request, the first seq_len entries of each
        # request's row in the req_to_token table.
        self.kv_indices = torch.cat(
            [
                self.req_to_token_pool.req_to_token[
                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
                ]
                for i in range(self.batch_size)
            ],
            dim=0,
        ).contiguous()

        workspace_buffer = torch.empty(
            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
        )
        if (
            self.forward_mode == ForwardMode.PREFILL
            or self.forward_mode == ForwardMode.EXTEND
        ):
            # create() only runs init_extend_args() for EXTEND, so in PREFILL
            # mode extend_seq_lens is still None. Fall back to the same
            # per-request lengths (seq_lens - prefix_lens) that create() uses
            # to build the position ids, instead of crashing in cumsum.
            qo_lens = self.extend_seq_lens
            if qo_lens is None:
                qo_lens = self.seq_lens - self.prefix_lens

            self.qo_indptr = torch.zeros(
                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
            )
            self.qo_indptr[1:] = torch.cumsum(qo_lens, dim=0)
            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            args = [
                self.qo_indptr,
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
            ]

            self.prefill_wrapper.begin_forward(*args)
        else:
            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                workspace_buffer, "NHD"
            )
            # NOTE(review): the three trailing positional arguments are
            # presumably page size, positional-encoding mode and KV data
            # type -- confirm against the installed flashinfer version.
            self.decode_wrapper.begin_forward(
                self.kv_indptr,
                self.kv_indices,
                self.kv_last_page_len,
                self.model_runner.model_config.num_attention_heads // tp_size,
                self.model_runner.model_config.num_key_value_heads // tp_size,
                self.model_runner.model_config.head_dim,
                1,
                "NONE",
                "float16",
            )

    def init_extend_args(self):
        """Derive the extend-only fields from ``seq_lens`` and ``prefix_lens``.

        Sets ``extend_seq_lens`` (tokens per request beyond the prefix),
        ``extend_start_loc`` (exclusive running sum of those lengths) and
        ``max_extend_len``.
        """
        self.extend_seq_lens = self.seq_lens - self.prefix_lens
        self.extend_start_loc = torch.zeros_like(self.seq_lens)
        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
        self.max_extend_len = int(torch.max(self.extend_seq_lens))

    @classmethod
    def create(
        cls,
        model_runner,
        tp_size,
        forward_mode,
        req_pool_indices,
        seq_lens,
        prefix_lens,
        position_ids_offsets,
        out_cache_loc,
        out_cache_cont_start=None,
        out_cache_cont_end=None,
        top_logprobs_nums=None,
        return_logprob=False,
    ):
        """Build an ``InputMetadata`` for one forward pass.

        Args:
            model_runner: Owning runner; supplies the two memory pools and
                the model config.
            tp_size: Tensor-parallel size, forwarded to
                ``init_flashinfer_args``.
            forward_mode: ``ForwardMode`` of this pass.
            req_pool_indices: Per-request row indices into the
                ``req_to_token`` table.
            seq_lens: Per-request total sequence lengths.
            prefix_lens: Per-request prefix lengths; only read outside
                DECODE mode (and by the extend/flashinfer initialisers).
            position_ids_offsets: Per-request offsets added to the position
                ids.
            out_cache_loc, out_cache_cont_start, out_cache_cont_end,
            top_logprobs_nums, return_logprob: Stored on the result
                unchanged.

        Returns:
            The populated ``InputMetadata``.
        """
        batch_size = len(req_pool_indices)
        # start_loc is the exclusive running sum of seq_lens.
        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
        total_num_tokens = int(torch.sum(seq_lens))
        max_seq_len = int(torch.max(seq_lens))

        if forward_mode == ForwardMode.DECODE:
            # One position per request: the last token of each sequence.
            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
            other_kv_index = model_runner.req_to_token_pool.req_to_token[
                req_pool_indices[0], seq_lens[0] - 1
            ].item()
        else:
            # One position per non-prefix token: the range
            # [prefix_len, seq_len) of every request, shifted by its offset
            # and concatenated in batch order.
            seq_lens_cpu = seq_lens.cpu().numpy()
            prefix_lens_cpu = prefix_lens.cpu().numpy()
            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
            positions = torch.tensor(
                np.concatenate(
                    [
                        np.arange(
                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
                        )
                        for i in range(batch_size)
                    ],
                    axis=0,
                ),
                device="cuda",
            )
            other_kv_index = None

        ret = cls(
            model_runner=model_runner,
            forward_mode=forward_mode,
            batch_size=batch_size,
            total_num_tokens=total_num_tokens,
            max_seq_len=max_seq_len,
            req_pool_indices=req_pool_indices,
            start_loc=start_loc,
            seq_lens=seq_lens,
            prefix_lens=prefix_lens,
            positions=positions,
            req_to_token_pool=model_runner.req_to_token_pool,
            token_to_kv_pool=model_runner.token_to_kv_pool,
            out_cache_loc=out_cache_loc,
            out_cache_cont_start=out_cache_cont_start,
            out_cache_cont_end=out_cache_cont_end,
            other_kv_index=other_kv_index,
            return_logprob=return_logprob,
            top_logprobs_nums=top_logprobs_nums,
        )

        if forward_mode == ForwardMode.EXTEND:
            ret.init_extend_args()

        if global_server_args_dict.get("enable_flashinfer", False):
            ret.init_flashinfer_args(tp_size)

        return ret
|
212
|
+
|
213
|
+
|
214
|
+
class ModelRunner:
    """Owns one model replica on one GPU and runs its forward passes.

    Construction selects the CUDA device, initialises the distributed
    environment and tensor-model parallelism, loads the weights and
    allocates the request/token memory pools.
    """

    def __init__(
        self,
        model_config,
        mem_fraction_static: float,
        gpu_id: int,
        tp_rank: int,
        tp_size: int,
        nccl_port: int,
        server_args: ServerArgs,
    ):
        """Initialise the device, distributed state, model and memory pools.

        Args:
            model_config: Model configuration (head counts, head dim, layer
                count, context length, optional override args).
            mem_fraction_static: Fraction of total GPU memory used in the
                KV-cache budget computation (see ``profile_max_num_token``).
            gpu_id: CUDA device index; also used as the local rank.
            tp_rank: Rank of this worker in the tensor-parallel group.
            tp_size: Tensor-parallel world size.
            nccl_port: Port on 127.0.0.1 used for the distributed init
                method.
            server_args: Server options; ``enable_flashinfer``,
                ``attention_reduce_in_fp32`` and the model-loading options
                are read from it.

        Raises:
            ValueError: If ``tp_size > 1`` and this GPU's available memory
                is below 90% of the value reported by the distributed query.
            RuntimeError: If no memory is left for the KV cache.
        """
        self.model_config = model_config
        self.mem_fraction_static = mem_fraction_static
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.nccl_port = nccl_port
        self.server_args = server_args

        # Mutate the module-level dict in place instead of rebinding the
        # name, so that any module which already imported
        # ``global_server_args_dict`` sees the values rather than keeping a
        # reference to the stale, empty dict.
        new_server_args = {
            "enable_flashinfer": server_args.enable_flashinfer,
            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
        }
        global_server_args_dict.clear()
        global_server_args_dict.update(new_server_args)

        # Init torch distributed
        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
        torch.cuda.set_device(self.gpu_id)
        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
        monkey_patch_vllm_p2p_access_check()
        init_distributed_environment(
            backend="nccl",
            world_size=self.tp_size,
            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        total_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )

        if self.tp_size > 1:
            # Refuse to start when this GPU has noticeably less free memory
            # than the value obtained from the distributed query.
            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
            if total_local_gpu_memory < total_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

        self.load_model()
        self.init_memory_pool(total_gpu_memory)
        self.is_multimodal_model = is_multimodal_model(self.model_config)

    def load_model(self):
        """Load the model through vllm's ``get_model`` into ``self.model``.

        The vllm model config is built from ``server_args`` with a fixed
        float16 dtype and seed 42; ``model_config.model_overide_args``, when
        set, is merged into the HF config before loading.
        """
        logger.info(
            f"[gpu_id={self.gpu_id}] Load weight begin. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

        device_config = DeviceConfig()
        load_config = LoadConfig(load_format=self.server_args.load_format)
        vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
            tokenizer_mode=None,
            trust_remote_code=self.server_args.trust_remote_code,
            dtype=torch.float16,
            seed=42,
            skip_tokenizer_init=True,
        )
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)

        self.model = get_model(
            model_config=vllm_model_config,
            device_config=device_config,
            load_config=load_config,
            lora_config=None,
            vision_language_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )
        logger.info(
            f"[gpu_id={self.gpu_id}] Load weight end. "
            f"type={type(self.model).__name__}, "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def profile_max_num_token(self, total_gpu_memory):
        """Return how many tokens the KV cache can hold.

        Args:
            total_gpu_memory: Memory figure measured before the model was
                loaded, in the same unit as ``get_available_gpu_memory``
                (logged as GB elsewhere in this class).

        Returns:
            The token budget as an int; may be zero or negative when memory
            is insufficient.
        """
        available_gpu_memory = get_available_gpu_memory(
            self.gpu_id, distributed=self.tp_size > 1
        )
        head_dim = self.model_config.head_dim
        head_num = self.model_config.num_key_value_heads // self.tp_size
        # NOTE(review): the trailing "* 2 * 2" is presumably (key + value)
        # times 2 bytes per float16 element -- confirm against TokenToKVPool.
        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
        # Keep (1 - mem_fraction_static) of the total memory out of the
        # KV-cache budget.
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    def init_memory_pool(self, total_gpu_memory):
        """Allocate ``req_to_token_pool`` and ``token_to_kv_pool``.

        Args:
            total_gpu_memory: Forwarded to ``profile_max_num_token``.

        Raises:
            RuntimeError: If the computed token budget is not positive.
        """
        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

        if self.max_total_num_tokens <= 0:
            raise RuntimeError(
                "Not enought memory. Please try to increase --mem-fraction-static."
            )

        self.req_to_token_pool = ReqToTokenPool(
            int(self.max_total_num_tokens / self.model_config.context_len * 256),
            self.model_config.context_len + 8,
        )
        self.token_to_kv_pool = TokenToKVPool(
            self.max_total_num_tokens,
            dtype=torch.float16,
            head_num=self.model_config.num_key_value_heads // self.tp_size,
            head_dim=self.model_config.head_dim,
            layer_num=self.model_config.num_hidden_layers,
        )
        logger.info(
            f"[gpu_id={self.gpu_id}] Memory pool end. "
            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
        )

    def _create_input_metadata(self, batch: Batch, forward_mode, **extra_kwargs):
        """Build the ``InputMetadata`` shared by every ``forward_*`` method.

        ``extra_kwargs`` carries mode-specific arguments (the decode path
        passes the contiguous out-cache bounds) so that batch attributes
        other modes never touched are still not read for them.
        """
        return InputMetadata.create(
            self,
            forward_mode=forward_mode,
            tp_size=self.tp_size,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
            **extra_kwargs,
        )

    @torch.inference_mode()
    def forward_prefill(self, batch: Batch):
        """Run the model on ``batch`` in PREFILL mode."""
        input_metadata = self._create_input_metadata(batch, ForwardMode.PREFILL)
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend(self, batch: Batch):
        """Run the model on ``batch`` in EXTEND mode."""
        input_metadata = self._create_input_metadata(batch, ForwardMode.EXTEND)
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_decode(self, batch: Batch):
        """Run the model on ``batch`` in DECODE mode."""
        input_metadata = self._create_input_metadata(
            batch,
            ForwardMode.DECODE,
            out_cache_cont_start=batch.out_cache_cont_start,
            out_cache_cont_end=batch.out_cache_cont_end,
        )
        return self.model.forward(
            batch.input_ids, input_metadata.positions, input_metadata
        )

    @torch.inference_mode()
    def forward_extend_multi_modal(self, batch: Batch):
        """Run a multimodal model on ``batch`` in EXTEND mode.

        In addition to the usual inputs, the batch's pixel values, image
        sizes and image offsets are passed to the model.
        """
        input_metadata = self._create_input_metadata(batch, ForwardMode.EXTEND)
        return self.model.forward(
            batch.input_ids,
            input_metadata.positions,
            input_metadata,
            batch.pixel_values,
            batch.image_sizes,
            batch.image_offsets,
        )

    def forward(self, batch: Batch, forward_mode: ForwardMode):
        """Dispatch ``batch`` to the forward method for ``forward_mode``.

        Multimodal models use the multimodal path for EXTEND only; DECODE
        and PREFILL go through the regular methods.

        Raises:
            ValueError: If ``forward_mode`` is not DECODE, EXTEND or PREFILL.
        """
        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
            return self.forward_extend_multi_modal(batch)
        elif forward_mode == ForwardMode.DECODE:
            return self.forward_decode(batch)
        elif forward_mode == ForwardMode.EXTEND:
            return self.forward_extend(batch)
        elif forward_mode == ForwardMode.PREFILL:
            return self.forward_prefill(batch)
        else:
            raise ValueError(f"Invaid forward mode: {forward_mode}")
|
431
|
+
|
432
|
+
|
433
|
+
@lru_cache()
def import_model_classes():
    """Collect the model classes exported by ``sglang.srt.models``.

    Every non-package module directly inside the package is imported. A
    module takes part by defining ``EntryClass``, either a single class or a
    list of classes. The result is cached after the first call.

    Returns:
        Dict mapping each exported class's ``__name__`` to the class.
    """
    models_pkg = "sglang.srt.models"
    search_path = importlib.import_module(models_pkg).__path__

    registry = {}
    for module_info in pkgutil.iter_modules(search_path, models_pkg + "."):
        if module_info.ispkg:
            continue
        loaded = importlib.import_module(module_info.name)
        if not hasattr(loaded, "EntryClass"):
            continue
        exported = loaded.EntryClass
        # A list lets one module register several model classes.
        classes = exported if isinstance(exported, list) else [exported]
        for model_cls in classes:
            registry[model_cls.__name__] = model_cls
    return registry
|
449
|
+
|
450
|
+
|
451
|
+
def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    """Look up the model class registered under ``model_arch``.

    Args:
        model_arch: Architecture name; matched against the class names
            collected by ``import_model_classes``.

    Returns:
        The matching model class.

    Raises:
        ValueError: If no class with that name is registered; the message
            lists the supported names.
    """
    registry = import_model_classes()
    if model_arch in registry:
        return registry[model_arch]
    raise ValueError(
        f"Unsupported architectures: {model_arch}. "
        f"Supported list: {list(registry.keys())}"
    )
|
459
|
+
|
460
|
+
|
461
|
+
# Monkey patch model loader: make vllm's ModelRegistry resolve architecture
# names through load_model_cls_srt.
ModelRegistry.load_model_cls = load_model_cls_srt
|