sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
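Most of the churn above is in internal modules: the old sglang/srt/managers/router package is removed and its pieces land in the new sglang/srt/managers/controller package. For code that reached into these internals, a defensive import sketch (assuming the RadixCache class keeps its name across the rename, which the moved radix_cache.py above suggests) could look like:

try:
    # 0.1.21 layout
    from sglang.srt.managers.controller.radix_cache import RadixCache
except ImportError:
    # 0.1.14 layout
    from sglang.srt.managers.router.radix_cache import RadixCache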
sglang/srt/managers/router/model_runner.py
DELETED
@@ -1,514 +0,0 @@
-import importlib
-import logging
-import inspect
-from dataclasses import dataclass
-from functools import lru_cache
-from pathlib import Path
-import importlib.resources
-
-import numpy as np
-import torch
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
-from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader import _set_default_torch_dtype
-from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
-
-import importlib
-import pkgutil
-
-import sglang
-
-QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
-
-logger = logging.getLogger("model_runner")
-
-
-# for server args in model endpoints
-global_server_args_dict: dict = None
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
-@dataclass
-class InputMetadata:
-    model_runner: "ModelRunner"
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
-
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    other_kv_index: torch.Tensor = None
-    return_logprob: bool = False
-
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
-
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
-        self.kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    self.req_pool_indices[i].item(), : self.seq_lens[i].item()
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            self.qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            args = [
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-            ]
-
-            # flashinfer >= 0.0.3
-            # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
-                args.append(self.model_runner.model_config.head_dim)
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
-            )
-            self.decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-                1,
-                "NONE",
-                "float16",
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        tp_size,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
-        return_logprob=False,
-    ):
-        batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
-        else:
-            seq_lens_np = seq_lens.cpu().numpy()
-            prefix_lens_np = prefix_lens.cpu().numpy()
-            position_ids_offsets_np = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_np[i] + position_ids_offsets_np[i],
-                            seq_lens_np[i] + position_ids_offsets_np[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            other_kv_index = None
-
-        ret = cls(
-            model_runner=model_runner,
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
-            req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
-            other_kv_index=other_kv_index,
-        )
-
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
-
-        return ret
-
-
-class ModelRunner:
-    def __init__(
-        self,
-        model_config,
-        mem_fraction_static,
-        tp_rank,
-        tp_size,
-        nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
-    ):
-        self.model_config = model_config
-        self.mem_fraction_static = mem_fraction_static
-        self.tp_rank = tp_rank
-        self.tp_size = tp_size
-        self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
-
-        global global_server_args_dict
-        global_server_args_dict = server_args_dict
-
-        # Init torch distributed
-        torch.cuda.set_device(self.tp_rank)
-        torch.distributed.init_process_group(
-            backend="nccl",
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
-        )
-
-        # A small all_reduce for warmup.
-        if self.tp_size > 1:
-            torch.distributed.all_reduce(torch.zeros(1).cuda())
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
-        self.load_model()
-        self.init_memory_pool(total_gpu_memory)
-
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
-
-    def load_model(self):
-        """See also vllm/model_executor/model_loader.py::get_model"""
-        # Select model class
-        architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = get_model_cls_by_arch_name(architectures)
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
-
-        # Load weights
-        linear_method = None
-        with _set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                hf_quant_config = getattr(
-                    self.model_config.hf_config, "quantization_config", None
-                )
-                if hf_quant_config is not None:
-                    hf_quant_method = hf_quant_config["quant_method"]
-
-                    # compat: autogptq uses is_marlin_format within quant config
-                    if (hf_quant_method == "gptq"
-                        and "is_marlin_format" in hf_quant_config
-                        and hf_quant_config["is_marlin_format"]):
-                        hf_quant_method = "marlin"
-                    quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                    if quant_config_class is None:
-                        raise ValueError(
-                            f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                        )
-                    quant_config = quant_config_class.from_config(hf_quant_config)
-                    logger.info(f"quant_config: {quant_config}")
-                    linear_method = quant_config.get_linear_method()
-                model = model_class(
-                    config=self.model_config.hf_config, linear_method=linear_method
-                )
-            model.load_weights(
-                self.model_config.path,
-                cache_dir=None,
-                load_format=self.load_format,
-                revision=None,
-            )
-            self.model = model.eval()
-
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
-
-    def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
-        rest_memory = available_gpu_memory - total_gpu_memory * (
-            1 - self.mem_fraction_static
-        )
-        max_num_token = int(rest_memory // cell_size)
-        return max_num_token
-
-    def init_memory_pool(self, total_gpu_memory):
-        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
-
-        if self.max_total_num_token <= 0:
-            raise RuntimeError(
-                "Not enought memory. " "Please try to increase --mem-fraction-static."
-            )
-
-        self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_token / self.model_config.context_len * 256),
-            self.model_config.context_len + 8,
-        )
-        self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_token,
-            dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
-            head_dim=self.model_config.head_dim,
-            layer_num=self.model_config.num_hidden_layers,
-        )
-
-    @torch.inference_mode()
-    def forward_prefill(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
-        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
-
-    @torch.inference_mode()
-    def forward_extend(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
-        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
-
-    @torch.inference_mode()
-    def forward_decode(
-        self,
-        input_ids,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start,
-        out_cache_cont_end,
-        return_logprob,
-    ):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            return_logprob=return_logprob,
-        )
-        return self.model.forward(input_ids, input_metadata.positions, input_metadata)
-
-    @torch.inference_mode()
-    def forward_extend_multi_modal(
-        self,
-        input_ids,
-        pixel_values,
-        image_sizes,
-        image_offsets,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        return_logprob,
-    ):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            position_ids_offsets=position_ids_offsets,
-            out_cache_loc=out_cache_loc,
-            return_logprob=return_logprob,
-        )
-        return self.model.forward(
-            input_ids,
-            input_metadata.positions,
-            input_metadata,
-            pixel_values,
-            image_sizes,
-            image_offsets,
-        )
-
-    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "pixel_values": batch.pixel_values,
-                "image_sizes": batch.image_sizes,
-                "image_offsets": batch.image_offsets,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-            return self.forward_extend_multi_modal(**kwargs)
-        else:
-            kwargs = {
-                "input_ids": batch.input_ids,
-                "req_pool_indices": batch.req_pool_indices,
-                "seq_lens": batch.seq_lens,
-                "prefix_lens": batch.prefix_lens,
-                "position_ids_offsets": batch.position_ids_offsets,
-                "out_cache_loc": batch.out_cache_loc,
-                "return_logprob": return_logprob,
-            }
-
-            if forward_mode == ForwardMode.DECODE:
-                kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-                kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-                return self.forward_decode(**kwargs)
-            elif forward_mode == ForwardMode.EXTEND:
-                return self.forward_extend(**kwargs)
-            elif forward_mode == ForwardMode.PREFILL:
-                return self.forward_prefill(**kwargs)
-            else:
-                raise ValueError(f"Invaid forward mode: {forward_mode}")
sglang/srt/managers/router/scheduler.py
DELETED
@@ -1,70 +0,0 @@
-import random
-from collections import defaultdict
-
-
-class Scheduler:
-    def __init__(
-        self,
-        schedule_heuristic,
-        max_running_seq,
-        max_prefill_num_token,
-        max_total_num_token,
-        tree_cache,
-    ):
-        self.schedule_heuristic = schedule_heuristic
-        self.max_running_seq = max_running_seq
-        self.max_prefill_num_token = max_prefill_num_token
-        self.max_total_num_token = max_total_num_token
-        self.tree_cache = tree_cache
-
-    def get_priority_queue(self, forward_queue):
-        if self.schedule_heuristic == "lpm":
-            # longest prefix match
-            forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return forward_queue
-        elif self.schedule_heuristic == "random":
-            random.shuffle(forward_queue)
-            return forward_queue
-        elif self.schedule_heuristic == "fcfs":
-            return forward_queue
-        elif self.schedule_heuristic == "weight":
-            last_node_to_reqs = defaultdict(list)
-            for req in forward_queue:
-                last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
-
-            node_to_weight = defaultdict(int)
-            self._calc_weight_recursive(
-                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
-            )
-
-            tmp_queue = []
-            self._get_weight_priority_recursive(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
-            )
-            assert len(tmp_queue) == len(forward_queue)
-            return tmp_queue
-        else:
-            raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
-
-    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
-        for child in cur_node.children.values():
-            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
-            node_to_weight[cur_node] += node_to_weight[child]
-
-    def _get_weight_priority_recursive(
-        self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
-    ):
-        visit_list = [child for child in cur_node.children.values()]
-        visit_list.sort(key=lambda x: -node_to_wight[x])
-        # for node in visit_list:
-        #     print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
sglang-0.1.14.dist-info/RECORD
DELETED
@@ -1,64 +0,0 @@
-sglang/__init__.py,sha256=Nxa2M7XCh2-e6I7VrCg7OSBL6BvEW3gyRD14ZdykpRM,96
-sglang/api.py,sha256=0-Eh7c41hWKjPXrzzvLFdLAUVkvmPGJGLAsrG9evDTE,4576
-sglang/global_config.py,sha256=PAX7TWeFcq0HBzNUWyCONAOjqIokWqw8vT7I6sBSKTc,797
-sglang/launch_server.py,sha256=jKPZRDN5bUe8Wgz5eoDkqeePhmKa8DLD4DpXQLT5auo,294
-sglang/utils.py,sha256=2dUXLMPz9VhhzbIRQABmfZnVW5yz61F3UVtb6yKyevM,6237
-sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sglang/backend/anthropic.py,sha256=GJ_T1Jg0VOtajgkgczPKt5sjuVYdbAiWd2jXlJRNRmg,1677
-sglang/backend/base_backend.py,sha256=APiMht4WYECLCOGRPCEUF6lX-an1vjVe2dWoMSgymWY,1831
-sglang/backend/openai.py,sha256=nPdA88A5GISJTH88svJdww3qHWIHZcGG2NEn0XjMkLU,9578
-sglang/backend/runtime_endpoint.py,sha256=r7dTazselaudlFx8hqk-PQLYDHZhpbAKjyFF1zLuM_E,8022
-sglang/backend/vertexai.py,sha256=BLfWf_tEgoHY9srCufJM5PLe3tql2j0G6ia7cPykxCM,4713
-sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-sglang/lang/chat_template.py,sha256=MaCF0fvNky0nJC9OvmAeApeHYgM6Lr03mtRhF0lS31U,8000
-sglang/lang/compiler.py,sha256=wNn_UqV6Sxl22mv-PpzFUtRgiFFV-Y4OYpO4LshEoRM,7527
-sglang/lang/interpreter.py,sha256=ahRxuEJZ7b1Tts2Lr7wViWIqL-Z12T3anvgj0XdvMN8,26666
-sglang/lang/ir.py,sha256=8Ap-uEUz6K9eNQTOKtMixePuLwRFHFKcN0Z5Yn44nKk,13320
-sglang/lang/tracer.py,sha256=pFiSNzPSg0l7ZZIlGqJDLCmQALR-wyo2dFgJP73J4_Y,8260
-sglang/srt/backend_config.py,sha256=UIV6kIU2j-Xh0eoezn1aXcYIy0miftHsWFeAZwqpbGE,227
-sglang/srt/conversation.py,sha256=mTstD-SsXG5p_YhWQUPEWU-vzzDMF4RgQ7KmLkOOC7U,15496
-sglang/srt/hf_transformers_utils.py,sha256=soRyYLoCn7GxgxvonufGFkdFBA3eH5i3Izk_wi7p1l0,5285
-sglang/srt/memory_pool.py,sha256=BMoX2wvicj214mV-xvcr_Iv_Je0qs3zTuzXfQVpV8u4,3609
-sglang/srt/mm_utils.py,sha256=OptgAHDX-73Bk4jAdr2BOAJtiEXJNzPrMhaM-dy275c,8889
-sglang/srt/model_config.py,sha256=ned-odjmKBKBhVPo04FEpus9gJsUWxrFLrLxahLwSaw,1328
-sglang/srt/sampling_params.py,sha256=83Fp-4HWThC20TEh139XcIb_erBqfI7KZg5txdRBq7c,2896
-sglang/srt/server.py,sha256=WLXissKuXQI7JFb2V8D47QSF-PPHnW-JZCiQm4YW0xE,24070
-sglang/srt/server_args.py,sha256=bvbi-Rb_JudqztFFfRsuXBYtUsG9hq4zMFt7X97uDhA,8954
-sglang/srt/utils.py,sha256=IEqpmWx_hl4eXn_KoHM0EPXmxeN2wKkgK7H01_t0x5Q,7355
-sglang/srt/constrained/__init__.py,sha256=BPRNDJnWtzYJ13X4urRS5aE6wFuwAVNBA9qeWIHF8rE,1236
-sglang/srt/constrained/base_cache.py,sha256=QQjmFEiT8jlOskJoZobhrDl2TKB-B4b1LPQo9JQCP_w,1405
-sglang/srt/constrained/fsm_cache.py,sha256=20mEgtDXU1Zeoicl5KBQC3arkg-RhRWiYnchJc00m1g,901
-sglang/srt/constrained/jump_forward.py,sha256=Z-pz2Jnvk1CxSEZA65OVq0GryqdiKuOkhhc13v5T6Lo,2482
-sglang/srt/layers/context_flashattention_nopad.py,sha256=TVYQ6IjftWVXORmKpEROMqQxDOnF6n2g0G1Ci4LquYM,5209
-sglang/srt/layers/extend_attention.py,sha256=KGqQOA5mel9qScXMAQP_3Qyhp3BNbiQ7Y_6wi38Lxcs,12622
-sglang/srt/layers/logits_processor.py,sha256=MW2bpqSXyghODMojqeMSYWZhUHuAFPk_gUkyyLw9HkM,4827
-sglang/srt/layers/radix_attention.py,sha256=bqrb8H8K8RbKTr1PzVmpnUxRzMj0H-OWCi1JYZKuRDw,5597
-sglang/srt/layers/token_attention.py,sha256=waOjGsWZlvf6epFhYerRJlAaMwvDTy_Z3uzPaXsVQUU,8516
-sglang/srt/managers/detokenizer_manager.py,sha256=1lPNh_Pe6Pr0v-TzlCBBREbvz4uFWxyw31SmnEZh0s8,3292
-sglang/srt/managers/io_struct.py,sha256=nXJh3CrOvv9MdAfIFoo6SCXuNQTG3KswmRKkwF61Tek,3141
-sglang/srt/managers/openai_protocol.py,sha256=cttqg9iv3de8fhtCqDI4cYoPPZ_gULedMXstV1ok6WA,4563
-sglang/srt/managers/tokenizer_manager.py,sha256=hgsR9AMj6ic9S3-2WiELh7Hnp8Xnb_bzp7kpbjHwHtM,9733
-sglang/srt/managers/router/infer_batch.py,sha256=U-Ckt9ad1WaOQF_dW6Eo9AMIRQoOJQ-Pm-MMXnEmPP8,18399
-sglang/srt/managers/router/manager.py,sha256=TNYs0IrkZGkPvZJViwL7BMUg0VlvzeyTjDMjuvRoMDI,2529
-sglang/srt/managers/router/model_rpc.py,sha256=VlwLNpHZ92bnteQl4PhVKoAXM0C8Y4_2LBBVaffeu3g,26766
-sglang/srt/managers/router/model_runner.py,sha256=-wWv00EbB_UkkLpio6VKGBTagfzxLHfY-eKDDQ0rZQc,18292
-sglang/srt/managers/router/radix_cache.py,sha256=XGUF5mxQTSCzD7GW_ltNP2p5aelEKrMXzdezufJ7NCQ,6484
-sglang/srt/managers/router/scheduler.py,sha256=V-LAnVSzgD2ddy2eXW3jWURCeq9Lv7YxCGk4kHyytfM,2818
-sglang/srt/models/gemma.py,sha256=8XlfHPtVixPYYjz5F9T4DOAuoordWFStmyFFWGfny1k,11582
-sglang/srt/models/llama2.py,sha256=VL4iN8R3wyTNr0bDxxKdLNnVGEvdXF6iGvA768YeakA,11611
-sglang/srt/models/llava.py,sha256=42sn-AgI-6dMaTEU4aEbi4Js5epy0J3JVQoMooUOKt8,14922
-sglang/srt/models/mistral.py,sha256=XSn7fiZqspyWVTYrpVAacAnWdwAybBtyn9-Sh9AvMTM,254
-sglang/srt/models/mixtral.py,sha256=wqIwKfR90ih0gDiTZkFZcQD4PIYpZFD3CmzxRcuKIqw,13915
-sglang/srt/models/qwen.py,sha256=CvdbcF90aI1tJPSQ-3OMUaQGMuaxCGe0y29m5nU_Yj0,9225
-sglang/srt/models/qwen2.py,sha256=myPc0wvgf5ZzJyGhUGN49YjY-tMf4t8Jn_Imjg8D7Mk,11307
-sglang/srt/models/stablelm.py,sha256=vMZUNgwXKPGYr5FcdYHw5g3QifVu9owKqq51_-EBOY0,10817
-sglang/srt/models/yivl.py,sha256=Qvp-zQ93cOZGg3zVyaiQLhRsfXiLrQhxu9TyQP2FMm4,4414
-sglang/test/test_conversation.py,sha256=1zIrXcXiwEliPHgDAsqsQUA7JKzZ5fnQEU-U6L887FU,1592
-sglang/test/test_openai_protocol.py,sha256=eePzoskYR3PqfWczSVZvg8ja63qbT8TFUNEMyzDZpa8,1657
-sglang/test/test_programs.py,sha256=mrLhGuprwvx8ZJ-0Qe28E-iCw5Qv-9T0SAv1Jgo1AJw,11421
-sglang/test/test_utils.py,sha256=6PhTRi8UnR-BRNjit6aGu0M5lO0RebNQwEcDt712hE4,4830
-sglang-0.1.14.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-sglang-0.1.14.dist-info/METADATA,sha256=C5N0VOYRHixdJcsf4dExIvP-Q099kYBMKs_dA4LBXSM,28809
-sglang-0.1.14.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-sglang-0.1.14.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
-sglang-0.1.14.dist-info/RECORD,,
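For reference, each RECORD entry above is path,hash,size: the hash field is the urlsafe-base64 SHA-256 digest of the file with the trailing '=' padding stripped, per the wheel/PEP 376 RECORD format. A small sketch of how such a line is computed:

import base64
import hashlib

def record_hash(data: bytes) -> str:
    # urlsafe base64 of the sha256 digest, '=' padding removed
    digest = hashlib.sha256(data).digest()
    return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

content = b"print('hello')\n"
print(f"pkg/example.py,{record_hash(content)},{len(content)}")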
{sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE
File without changes
{sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt
File without changes