sglang 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +26 -0
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +40 -18
- sglang/global_config.py +21 -16
- sglang/lang/chat_template.py +41 -6
- sglang/lang/interpreter.py +5 -1
- sglang/lang/ir.py +61 -25
- sglang/srt/constrained/__init__.py +3 -2
- sglang/srt/hf_transformers_utils.py +7 -3
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +181 -167
- sglang/srt/layers/logits_processor.py +55 -19
- sglang/srt/layers/radix_attention.py +33 -59
- sglang/srt/layers/token_attention.py +4 -8
- sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
- sglang/srt/managers/controller/infer_batch.py +244 -36
- sglang/srt/managers/controller/manager_single.py +1 -1
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/tp_worker.py +39 -20
- sglang/srt/managers/detokenizer_manager.py +4 -2
- sglang/srt/managers/io_struct.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/memory_pool.py +33 -6
- sglang/srt/model_config.py +6 -0
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/llama2.py +3 -3
- sglang/srt/models/llama_classification.py +10 -7
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/openai_api_adapter.py +2 -2
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +18 -8
- sglang/srt/server_args.py +24 -20
- sglang/srt/utils.py +68 -35
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/METADATA +19 -13
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/RECORD +40 -36
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
- {sglang-0.1.18.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
@@ -4,11 +4,9 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-from dataclasses import dataclass
 from functools import lru_cache
-from typing import
+from typing import Optional, Type

-import numpy as np
 import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,8 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

-from sglang.
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -29,210 +28,6 @@ from sglang.srt.utils import (

 logger = logging.getLogger("srt.model_runner")

-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
-
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    other_kv_index: torch.Tensor = None
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            # extend part
-            self.qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1
-            )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        tp_size,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
-    ):
-        batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            other_kv_index = None
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
-            req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-        )
-
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim
-            )
-
-        return ret
-

 class ModelRunner:
     def __init__(
@@ -245,6 +40,7 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
     ):
+        # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.gpu_id = gpu_id
@@ -256,10 +52,12 @@ class ModelRunner:
         monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
-        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
         if server_args.nccl_init_addr:
             nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
         else:
@@ -269,7 +67,7 @@ class ModelRunner:
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -284,11 +82,8 @@ class ModelRunner:
         )

         # Set some global args
-
-        global_server_args_dict = {
-            "disable_flashinfer": server_args.disable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32

         # Load the model and create memory pool
         self.load_model()
@@ -296,6 +91,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()

+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -323,7 +121,7 @@ class ModelRunner:
             device_config=device_config,
             load_config=load_config,
             lora_config=None,
-
+            multimodal_config=None,
             parallel_config=None,
             scheduler_config=None,
             cache_config=None,
@@ -341,7 +139,13 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size =
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
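Note on the cell_size expression above: it is the per-token KV-cache footprint in bytes, i.e. key and value tensors (the factor 2) for every hidden layer, each head_num * head_dim elements wide, times the byte size of the model dtype. A quick sanity check with hypothetical dimensions (illustrative values only, not taken from this diff):

# Hypothetical example of the cell_size arithmetic (all values are assumptions).
head_num = 8             # KV heads per tensor-parallel rank (assumed)
head_dim = 128           # assumed
num_hidden_layers = 32   # assumed
element_size = 2         # bytes per element for fp16/bf16

cell_size = head_num * head_dim * num_hidden_layers * 2 * element_size
print(cell_size)  # 131072 bytes, i.e. 128 KiB of KV cache per token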
@@ -382,64 +186,60 @@ class ModelRunner:
         return c

     def init_flash_infer(self):
-        if
-
-
-
-
-        )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return

-
-
-
-
-
-
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-
-
-            )
-
-
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
-            )
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
         else:
-
-            self.flashinfer_decode_wrapper = None
+            use_tensor_cores = False

-
-
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
         )
-
-
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[1], "NHD"
         )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+        )
+
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner.capture(batch_size_list)

     @torch.inference_mode()
-    def
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.
-            tp_size=self.tp_size,
+            forward_mode=ForwardMode.DECODE,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
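The new forward_decode path above first asks the CudaGraphRunner (added in sglang/srt/managers/controller/cuda_graph_runner.py, whose body is not shown in this diff) whether the current batch size was captured and, if so, replays the recorded graph instead of re-launching kernels one by one. As a minimal, generic sketch of the capture/replay pattern this builds on, using only the public torch.cuda.CUDAGraph API and a toy module (this is not the CudaGraphRunner implementation):

import torch

# Toy module standing in for the decode step; static buffers are reused across replays.
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy fresh data into the static input, then launch the recorded kernels.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_output.sum().item())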
@@ -447,32 +247,23 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )

     @torch.inference_mode()
-    def
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.
-            tp_size=self.tp_size,
+            forward_mode=ForwardMode.EXTEND,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -483,17 +274,13 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
         return self.model.forward(
             batch.input_ids,
@@ -511,8 +298,6 @@ class ModelRunner:
             return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
             return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")

@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     start_rpyc_service_process,
-    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-
+            8192
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -314,11 +314,9 @@ class ModelTpServer:
         self.forward_queue.append(req)

     def get_new_fill_batch(self) -> Optional[Batch]:
-        if
-
-
-        ):
-            return None
+        running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+        if running_bs >= self.max_running_requests:
+            return

         # Compute matched prefix length
         for req in self.forward_queue:
@@ -368,9 +366,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and (
-
-
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
@@ -392,6 +392,10 @@ class ModelTpServer:
                     new_batch_input_tokens += req.extend_input_len
                 else:
                     break
+
+            if running_bs + len(can_run_list) >= self.max_running_requests:
+                break
+
         if len(can_run_list) == 0:
             return None

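Taken together with the 8192-token default above, the checks added to get_new_fill_batch form a simple admission policy: keep pulling queued requests while the prefill-token budget allows it (always admitting at least one request), and stop once running plus newly admitted requests would reach max_running_requests. A stripped-down, standalone sketch of that policy (a hypothetical helper, not the ModelTpServer code):

# Hypothetical, simplified form of the admission checks shown in the hunks above.
def select_prefill_batch(queue, running_bs, max_running_requests, max_prefill_tokens):
    # Requests are assumed to expose extend_input_len, as in the diff.
    can_run, new_input_tokens = [], 0
    for req in queue:
        # Respect the prefill-token budget, but always admit at least one request.
        if can_run and req.extend_input_len + new_input_tokens > max_prefill_tokens:
            break
        can_run.append(req)
        new_input_tokens += req.extend_input_len
        # Cap the number of concurrently running requests.
        if running_bs + len(can_run) >= max_running_requests:
            break
    return can_run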
@@ -452,7 +456,9 @@ class ModelTpServer:
                     next_token_ids,
                 ].tolist()
                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-                output.normalized_prompt_logprobs =
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )

             next_token_ids = next_token_ids.tolist()
         else:
@@ -582,7 +588,9 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append(
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
                 if req.top_logprobs_num > 0:
                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])

@@ -759,16 +767,27 @@ class ModelTpClient:
         with ThreadPoolExecutor(self.tp_size) as executor:
             # Launch model processes
             if server_args.nnodes == 1:
-                self.procs = list(
-
-
-
+                self.procs = list(
+                    executor.map(
+                        lambda args: start_rpyc_service_process(*args),
+                        [
+                            (ModelTpService, p)
+                            for p in model_port_args.model_tp_ports
+                        ],
+                    )
+                )
                 addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
             else:
-                addrs = [
-
-
-
+                addrs = [
+                    (ip, port)
+                    for ip, port in zip(
+                        model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                    )
+                ]
+
+            self.model_services = list(
+                executor.map(lambda args: connect_rpyc_service(*args), addrs)
+            )

             # Init model
             def init_model(i):
@@ -11,7 +11,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -57,6 +57,8 @@ class DetokenizerManager:
             output_strs = []
             for i in range(len(recv_obj.rids)):
                 new_text = read_texts[i][len(surr_texts[i]) :]
+                if recv_obj.finished_reason[i] is None:
+                    new_text = find_printable_text(new_text)
                 output_strs.append(recv_obj.decoded_texts[i] + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
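With this change the detokenizer only emits text that is safe to show incrementally: while a request is still running, find_printable_text holds back a suffix that may not yet decode cleanly (for example a token that ends in the middle of a multi-byte UTF-8 character). The helper lives in sglang/utils and its body is not part of this diff; a plausible sketch of the idea, under that assumption:

# Plausible sketch only: hold back a trailing replacement character produced by
# an incomplete byte sequence during incremental detokenization.
def find_printable_text_sketch(text: str) -> str:
    # "\ufffd" is the replacement character tokenizers emit while a byte
    # sequence is still incomplete; hold it back until more tokens arrive.
    while text.endswith("\ufffd"):
        text = text[:-1]
    return text

assert find_printable_text_sketch("héllo\ufffd") == "héllo"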
@@ -67,7 +69,7 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
-
+                    output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
                 )
sglang/srt/managers/io_struct.py CHANGED
@@ -316,7 +316,7 @@ class TokenizerManager:

             recv_obj.meta_info[i]["id"] = rid
             out_dict = {
-                "text": recv_obj.
+                "text": recv_obj.output_strs[i],
                 "meta_info": recv_obj.meta_info[i],
             }
             state.out_list.append(out_dict)
@@ -333,17 +333,18 @@ class TokenizerManager:
             ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
             )
-
-
-            "
-
-
-
-
-            "
-
-
-
+
+            if top_logprobs_num > 0:
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
         return ret

     def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
@@ -383,7 +384,7 @@ def get_pixel_values(
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
-        if image_size
+        if image_size is not None:
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):