sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries; it is provided for informational purposes only.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
--- sglang/srt/managers/router/model_runner.py
+++ /dev/null
@@ -1,445 +0,0 @@
-import importlib
-import importlib.resources
-import logging
-import pkgutil
-from dataclasses import dataclass
-from functools import lru_cache
-from typing import List, Optional, Type
-
-import numpy as np
-import torch
-import torch.nn as nn
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import initialize_model_parallel
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry
-
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model
-
-
-logger = logging.getLogger("model_runner")
-
-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    model_runner: "ModelRunner"
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
-
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    other_kv_index: torch.Tensor = None
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
-
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            self.qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            self.decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-                1,
-                "NONE",
-                "float16",
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        tp_size,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
-        top_logprobs_nums=None,
-        return_logprob=False,
-    ):
-        batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            other_kv_index = None
-
-        ret = cls(
-            model_runner=model_runner,
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
-            req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-        )
-
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
-
-        return ret
-
-
-class ModelRunner:
-    def __init__(
-        self,
-        model_config,
-        mem_fraction_static,
-        tp_rank,
-        tp_size,
-        nccl_port,
-        server_args: ServerArgs,
-    ):
-        self.model_config = model_config
-        self.mem_fraction_static = mem_fraction_static
-        self.tp_rank = tp_rank
-        self.tp_size = tp_size
-        self.nccl_port = nccl_port
-        self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
-        # Init torch distributed
-        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
-        torch.cuda.set_device(self.tp_rank)
-        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
-        torch.distributed.init_process_group(
-            backend="nccl",
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
-        )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        logger.info(f"[rank={self.tp_rank}] Init torch end.")
-
-        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
-
-        if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
-            if total_local_gpu_memory < total_gpu_memory * 0.9:
-                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")
-
-        self.load_model()
-        self.init_memory_pool(total_gpu_memory)
-
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
-
-    def load_model(self):
-        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
-
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
-            seed=42,
-            skip_tokenizer_init=True,
-        )
-        if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
-
-        self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
-            lora_config=None,
-            vision_language_config=None,
-            parallel_config=None,
-            scheduler_config=None,
-        )
-        logger.info(f"[rank={self.tp_rank}] Load weight end. "
-                    f"Type={type(self.model).__name__}. "
-                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
-
-    def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
-        rest_memory = available_gpu_memory - total_gpu_memory * (
-            1 - self.mem_fraction_static
-        )
-        max_num_token = int(rest_memory * (1 << 30) // cell_size)
-        return max_num_token
-
-    def init_memory_pool(self, total_gpu_memory):
-        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
-
-        if self.max_total_num_tokens <= 0:
-            raise RuntimeError(
-                "Not enought memory. " "Please try to increase --mem-fraction-static."
-            )
-
-        self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_tokens / self.model_config.context_len * 256),
-            self.model_config.context_len + 8,
-        )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_tokens,
-            dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
-
-    @torch.inference_mode()
-    def forward_prefill(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
-    @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
-    @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
-
-    @torch.inference_mode()
-    def forward_extend_multi_modal(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-        )
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            batch.pixel_values,
-            batch.image_sizes,
-            batch.image_offsets,
-        )
-
-    def forward(self, batch: Batch, forward_mode: ForwardMode):
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
-            return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
-            return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
-        else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(entry, list):  # To support multiple model classes in one module
-                    for cls in entry:
-                        model_arch_name_to_cls[cls.__name__] = cls
-                else:
-                    model_arch_name_to_cls[entry.__name__] = entry
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
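Note: the deleted `init_flashinfer_args` above packs the batch's KV-cache layout into CSR-style arrays before handing them to flashinfer: `kv_indptr` is a cumulative sum over `seq_lens`, and `kv_indices` concatenates each request's token slots so request `i` owns the slice `kv_indices[kv_indptr[i]:kv_indptr[i + 1]]`. A minimal sketch of that indexing on CPU (the lengths below are made up for illustration, not taken from a real run):

```python
import torch

# Three requests with KV-cache lengths 3, 5, and 2 (illustrative values).
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# kv_indptr[i] marks where request i's entries start inside kv_indices.
kv_indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
print(kv_indptr.tolist())  # [0, 3, 8, 10]

# kv_indices would be the concatenation of each request's token slots,
# so request 1 owns kv_indices[3:8]; flashinfer reads the batch layout
# from these two arrays plus kv_last_page_len.
```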
--- sglang/srt/managers/router/radix_cache.py
+++ /dev/null
@@ -1,267 +0,0 @@
-import heapq
-import time
-from collections import defaultdict
-
-import torch
-
-
-class TreeNode:
-    def __init__(self):
-        self.children = defaultdict(TreeNode)
-        self.parent = None
-        self.key = None
-        self.value = None
-        self.lock_ref = 0
-        self.last_access_time = time.time()
-
-    def __lt__(self, other: "TreeNode"):
-        return self.last_access_time < other.last_access_time
-
-
-def _key_match(key0, key1):
-    i = 0
-    for k0, k1 in zip(key0, key1):
-        if k0 != k1:
-            break
-        i += 1
-    return i
-
-
-class RadixCache:
-    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
-        self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
-        self.disable = disable
-        self.reset()
-
-    ##### Public API #####
-
-    def reset(self):
-        self.root_node = TreeNode()
-        self.root_node.key = []
-        self.root_node.value = []
-        self.root_node.lock_ref = 1
-        self.evictable_size_ = 0
-
-    def match_prefix(self, key):
-        if self.disable:
-            return [], self.root_node
-
-        value = []
-        last_node = [self.root_node]
-        self._match_prefix_helper(self.root_node, key, value, last_node)
-        if value:
-            value = torch.concat(value)
-        else:
-            value = torch.tensor([], dtype=torch.int64)
-        return value, last_node[0]
-
-    def insert(self, key, value=None):
-        if self.disable:
-            return 0
-
-        if value is None:
-            value = [x for x in key]
-        return self._insert_helper(self.root_node, key, value)
-
-    def cache_req(
-        self,
-        token_ids,
-        last_uncached_pos,
-        req_pool_idx,
-        del_in_memory_pool=True,
-        old_last_node=None,
-    ):
-        # Insert the request into radix cache
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        new_prefix_len = self.insert(token_ids, indices.clone())
-
-        if self.disable:
-            if del_in_memory_pool:
-                self.token_to_kv_pool.dec_refs(indices)
-            else:
-                return torch.tensor([], dtype=torch.int64), self.root_node
-
-        # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
-
-        if del_in_memory_pool:
-            self.req_to_token_pool.free(req_pool_idx)
-        else:
-            cached_indices, new_last_node = self.match_prefix(token_ids)
-            assert len(cached_indices) == len(token_ids)
-
-            self.req_to_token_pool.req_to_token[
-                req_pool_idx, last_uncached_pos : len(cached_indices)
-            ] = cached_indices[last_uncached_pos:]
-            self.dec_lock_ref(old_last_node)
-            self.inc_lock_ref(new_last_node)
-            return cached_indices, new_last_node
-
-    def pretty_print(self):
-        self._print_helper(self.root_node, 0)
-        print(f"#tokens: {self.total_size()}")
-
-    def total_size(self):
-        return self._total_size_helper(self.root_node)
-
-    def evict(self, num_tokens, evict_callback):
-        if self.disable:
-            return
-
-        leaves = self._collect_leaves()
-        heapq.heapify(leaves)
-
-        num_evicted = 0
-        while num_evicted < num_tokens and len(leaves):
-            x = heapq.heappop(leaves)
-
-            if x == self.root_node:
-                break
-            if x.lock_ref > 0:
-                continue
-
-            num_evicted += evict_callback(x.value)
-            self._delete_leaf(x)
-
-            if len(x.parent.children) == 0:
-                heapq.heappush(leaves, x.parent)
-
-    def inc_lock_ref(self, node: TreeNode):
-        delta = 0
-        while node != self.root_node:
-            if node.lock_ref == 0:
-                self.evictable_size_ -= len(node.value)
-                delta -= len(node.value)
-            node.lock_ref += 1
-            node = node.parent
-        return delta
-
-    def dec_lock_ref(self, node: TreeNode):
-        delta = 0
-        while node != self.root_node:
-            if node.lock_ref == 1:
-                self.evictable_size_ += len(node.value)
-                delta += len(node.value)
-            node.lock_ref -= 1
-            node = node.parent
-        return delta
-
-    def evictable_size(self):
-        return self.evictable_size_
-
-    ##### Internal Helper Functions #####
-
-    def _match_prefix_helper(self, node, key, value, last_node):
-        node.last_access_time = time.time()
-        if len(key) == 0:
-            return
-
-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
-            if prefix_len < len(child.key):
-                new_node = self._split_node(child.key, child, prefix_len)
-                value.append(new_node.value)
-                last_node[0] = new_node
-            else:
-                value.append(child.value)
-                last_node[0] = child
-                self._match_prefix_helper(child, key[prefix_len:], value, last_node)
-
-    def _split_node(self, key, child: TreeNode, split_len):
-        # new_node -> child
-        new_node = TreeNode()
-        new_node.children = {key[split_len:][0]: child}
-        new_node.parent = child.parent
-        new_node.lock_ref = child.lock_ref
-        new_node.key = child.key[:split_len]
-        new_node.value = child.value[:split_len]
-        child.parent = new_node
-        child.key = child.key[split_len:]
-        child.value = child.value[split_len:]
-        new_node.parent.children[key[:split_len][0]] = new_node
-        return new_node
-
-    def _insert_helper(self, node, key, value):
-        node.last_access_time = time.time()
-        if len(key) == 0:
-            return 0
-
-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
-
-            if prefix_len == len(child.key):
-                if prefix_len == len(key):
-                    return prefix_len
-                else:
-                    key = key[prefix_len:]
-                    value = value[prefix_len:]
-                    return prefix_len + self._insert_helper(child, key, value)
-
-            new_node = self._split_node(child.key, child, prefix_len)
-            return prefix_len + self._insert_helper(
-                new_node, key[prefix_len:], value[prefix_len:]
-            )
-
-        if len(key):
-            new_node = TreeNode()
-            new_node.parent = node
-            new_node.key = key
-            new_node.value = value
-            node.children[key[0]] = new_node
-            self.evictable_size_ += len(value)
-        return 0
-
-    def _print_helper(self, node: TreeNode, indent):
-        for _, child in node.children.items():
-            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
-            self._print_helper(child, indent=indent + 2)
-
-    def _delete_leaf(self, node):
-        for k, v in node.parent.children.items():
-            if v == node:
-                break
-        del node.parent.children[k]
-        self.evictable_size_ -= len(node.key)
-
-    def _total_size_helper(self, node):
-        x = len(node.value)
-        for child in node.children.values():
-            x += self._total_size_helper(child)
-        return x
-
-    def _collect_leaves(self):
-        ret_list = []
-
-        def dfs_(cur_node):
-            if len(cur_node.children) == 0:
-                ret_list.append(cur_node)
-
-            for x in cur_node.children.values():
-                dfs_(x)
-
-        dfs_(self.root_node)
-        return ret_list
-
-
-if __name__ == "__main__":
-    tree = RadixCache(None, None, False)
-
-    tree.insert("Hello")
-    tree.insert("Hello")
-    tree.insert("Hello_L.A.!")
-    # tree.insert("Hello_world! Happy")
-    # tree.insert("I love you!")
-    tree.pretty_print()
-
-    # print(tree.match_prefix("I love you! aha"))
-
-    # def evict_callback(x):
-    #     print("evict", x)
-    #     return len(x)

-    # tree.evict(5, evict_callback)
-    # tree.evict(10, evict_callback)
-    # tree.pretty_print()
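Note: the deleted `RadixCache` keys each edge of the tree by the first token of its segment; `_key_match` measures the shared prefix and `_split_node` breaks an edge when a new sequence diverges partway along it. A small sketch of the intended behavior, assuming the classes from the hunk above are in scope; the token ids and KV slot numbers are made up for illustration:

```python
import torch

# No memory pools attached, as in the module's own __main__ demo.
tree = RadixCache(None, None, disable=False)

# Keys are token-id sequences; values are the KV-cache slots holding those tokens.
tree.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13]))
tree.insert([1, 2, 7, 8], torch.tensor([10, 11, 20, 21]))
# The second insert shares the prefix [1, 2], so _split_node turns the single
# [1, 2, 3, 4] edge into [1, 2] with children [3, 4] and [7, 8].

value, last_node = tree.match_prefix([1, 2, 3, 9])
print(value.tolist())  # [10, 11, 12] -- the longest cached prefix is [1, 2, 3]
```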