sglang 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_latency.py +7 -3
- sglang/global_config.py +21 -17
- sglang/srt/layers/radix_attention.py +14 -37
- sglang/srt/layers/token_attention.py +2 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +172 -0
- sglang/srt/managers/controller/infer_batch.py +242 -34
- sglang/srt/managers/controller/model_runner.py +56 -283
- sglang/srt/managers/controller/tp_worker.py +8 -6
- sglang/srt/memory_pool.py +33 -6
- sglang/srt/server.py +1 -0
- sglang/srt/server_args.py +10 -4
- {sglang-0.1.19.dist-info → sglang-0.1.20.dist-info}/METADATA +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.20.dist-info}/RECORD +17 -16
- {sglang-0.1.19.dist-info → sglang-0.1.20.dist-info}/WHEEL +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.20.dist-info}/LICENSE +0 -0
- {sglang-0.1.19.dist-info → sglang-0.1.20.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py
CHANGED
@@ -4,11 +4,9 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-from dataclasses import dataclass
 from functools import lru_cache
-from typing import
+from typing import Optional, Type

-import numpy as np
 import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,8 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

-from sglang.
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -29,210 +28,6 @@ from sglang.srt.utils import (

 logger = logging.getLogger("srt.model_runner")

-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
-
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    other_kv_index: torch.Tensor = None
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            # extend part
-            self.qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        tp_size,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
-    ):
-        batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            other_kv_index = None
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
-            req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-        )
-
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim,
-            )
-
-        return ret
-

 class ModelRunner:
     def __init__(
@@ -245,6 +40,7 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
     ):
+        # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.gpu_id = gpu_id
@@ -256,7 +52,6 @@ class ModelRunner:
         monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
-        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")

@@ -287,11 +82,8 @@ class ModelRunner:
         )

         # Set some global args
-
-        global_server_args_dict =
-            "disable_flashinfer": server_args.disable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32

         # Load the model and create memory pool
         self.load_model()
@@ -299,6 +91,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()

+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -391,67 +186,60 @@ class ModelRunner:
         return c

     def init_flash_infer(self):
-        if
-
-
-
-
-        )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return

-
-
-
-
-
-
-            use_tensor_cores = False
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-
-
-            )
-
-
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-            )
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
         else:
-
-                self.flashinfer_prefill_wrapper_paged
-            ) = None
-            self.flashinfer_decode_wrapper = None
+            use_tensor_cores = False

-
-
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
         )
-
-
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[1], "NHD"
         )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+        )
+
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+        self.cuda_graph_runner.capture(batch_size_list)

     @torch.inference_mode()
-    def
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.
-            tp_size=self.tp_size,
+            forward_mode=ForwardMode.DECODE,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
@@ -459,32 +247,23 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )

     @torch.inference_mode()
-    def
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.
-            tp_size=self.tp_size,
+            forward_mode=ForwardMode.EXTEND,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -495,17 +274,13 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
         return self.model.forward(
             batch.input_ids,
@@ -523,8 +298,6 @@ class ModelRunner:
             return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
             return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")

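The new `init_cuda_graphs` / `forward_decode` path above delegates to the `CudaGraphRunner` added in `cuda_graph_runner.py` (+172 lines, not shown in this section). As a rough orientation, the sketch below shows the generic torch CUDA-graph capture-and-replay pattern such a runner is built on; the class and buffer names here are illustrative only, not sglang's actual implementation.

```python
import torch

class TinyGraphRunner:
    """Illustrative only: capture one forward pass into a CUDA graph, then replay it."""

    def __init__(self, model, max_bs, hidden_size):
        self.model = model
        # CUDA graphs replay fixed kernel launches, so inputs/outputs live in static buffers.
        self.static_in = torch.zeros(max_bs, hidden_size, device="cuda")
        # Warm up once outside capture (new allocations during capture are not allowed).
        self.model(self.static_in)
        torch.cuda.synchronize()
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_out = self.model(self.static_in)

    def replay(self, x):
        bs = x.shape[0]
        self.static_in[:bs].copy_(x)   # copy fresh inputs into the captured buffer
        self.static_in[bs:].zero_()    # pad the rest of the batch
        self.graph.replay()            # re-launch the recorded kernels
        return self.static_out[:bs].clone()

model = torch.nn.Linear(16, 16).cuda().eval()
runner = TinyGraphRunner(model, max_bs=8, hidden_size=16)
out = runner.replay(torch.randn(4, 16, device="cuda"))
```

Replaying a pre-captured graph avoids per-token kernel-launch overhead, which is why decode batches that fit one of the captured batch sizes take the `cuda_graph_runner.replay` fast path above.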
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -98,7 +98,7 @@ class ModelTpServer:
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
-
+            8192
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
@@ -314,11 +314,9 @@ class ModelTpServer:
|
|
314
314
|
self.forward_queue.append(req)
|
315
315
|
|
316
316
|
def get_new_fill_batch(self) -> Optional[Batch]:
|
317
|
-
if
|
318
|
-
|
319
|
-
|
320
|
-
):
|
321
|
-
return None
|
317
|
+
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
|
318
|
+
if running_bs >= self.max_running_requests:
|
319
|
+
return
|
322
320
|
|
323
321
|
# Compute matched prefix length
|
324
322
|
for req in self.forward_queue:
|
@@ -394,6 +392,10 @@ class ModelTpServer:
|
|
394
392
|
new_batch_input_tokens += req.extend_input_len
|
395
393
|
else:
|
396
394
|
break
|
395
|
+
|
396
|
+
if running_bs + len(can_run_list) >= self.max_running_requests:
|
397
|
+
break
|
398
|
+
|
397
399
|
if len(can_run_list) == 0:
|
398
400
|
return None
|
399
401
|
|
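The two tp_worker.py changes above cap scheduling by request count: `get_new_fill_batch` now bails out early when the running batch is already full, and stops admitting waiting requests once running plus newly selected requests would reach `max_running_requests`. A small standalone sketch of that admission rule (hypothetical helper, not the sglang scheduler itself):

```python
def admit_new_requests(waiting, running_bs, max_running_requests):
    """Pick requests from the waiting queue without exceeding max_running_requests."""
    if running_bs >= max_running_requests:
        return []
    can_run = []
    for req in waiting:
        if running_bs + len(can_run) >= max_running_requests:
            break
        can_run.append(req)
    return can_run

# With 6 requests already running and a cap of 8, only two more are admitted.
print(admit_new_requests(list(range(10)), running_bs=6, max_running_requests=8))  # [0, 1]
```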
sglang/srt/memory_pool.py
CHANGED
@@ -38,15 +38,24 @@ class ReqToTokenPool:

 class TokenToKVPool:
     def __init__(self, size, dtype, head_num, head_dim, layer_num):
-        self.
+        self.size = size
+        # mem_state is the reference counter.
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
         self.total_ref_ct = 0

         # [size, key/value, head_num, head_dim] for each layer
         self.kv_data = [
-            torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]

+        # Prefetch buffer
+        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+        self.prefetch_chunk_size = 512
+
+        self.clear()
+
     def get_key_buffer(self, layer_id):
         return self.kv_data[layer_id][:, 0]

@@ -54,14 +63,29 @@ class TokenToKVPool:
         return self.kv_data[layer_id][:, 1]

     def alloc(self, need_size):
-
-        if
+        buffer_len = len(self.prefetch_buffer)
+        if need_size <= buffer_len:
+            select_index = self.prefetch_buffer[:need_size]
+            self.prefetch_buffer = self.prefetch_buffer[need_size:]
+            return select_index
+
+        addition_size = need_size - buffer_len
+        alloc_size = max(addition_size, self.prefetch_chunk_size)
+        select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+
+        if select_index.shape[0] < addition_size:
             return None

         self.add_refs(select_index)
-
+
+        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+        ret_index = self.prefetch_buffer[:need_size]
+        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+        return ret_index

     def alloc_contiguous(self, need_size):
+        # NOTE: This function is deprecated.
         empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
         if empty_index.shape[0] < need_size:
             return None
@@ -84,7 +108,7 @@ class TokenToKVPool:
|
|
84
108
|
return len(torch.nonzero(self.mem_state).squeeze(1))
|
85
109
|
|
86
110
|
def available_size(self):
|
87
|
-
return torch.sum(self.mem_state == 0).item()
|
111
|
+
return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
|
88
112
|
|
89
113
|
def add_refs(self, token_index: torch.Tensor):
|
90
114
|
self.total_ref_ct += len(token_index)
|
@@ -101,3 +125,6 @@ class TokenToKVPool:
     def clear(self):
         self.mem_state.fill_(0)
         self.total_ref_ct = 0
+
+        # We also add one slot. This slot is used for writing dummy output from padded tokens.
+        self.add_refs(torch.tensor([0], dtype=torch.int32))
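Taken together, the memory_pool.py changes add a prefetch buffer so that `alloc` grabs free KV slots in chunks of at least `prefetch_chunk_size` and serves subsequent requests from that buffer, amortizing the `torch.nonzero` scan over `mem_state`; the extra `size + 1` slot absorbs dummy writes from padded tokens. A standalone sketch of the chunked-allocation idea (illustrative, not the sglang class):

```python
import torch

mem_state = torch.zeros(1024, dtype=torch.int16)      # 0 = free, >0 = referenced
prefetch_buffer = torch.empty(0, dtype=torch.int32)   # free slots fetched ahead of time
PREFETCH_CHUNK = 512

def alloc(need_size):
    global prefetch_buffer
    if need_size <= len(prefetch_buffer):              # fast path: serve from the buffer
        out, prefetch_buffer = prefetch_buffer[:need_size], prefetch_buffer[need_size:]
        return out
    missing = need_size - len(prefetch_buffer)
    chunk = max(missing, PREFETCH_CHUNK)               # scan once, grab a whole chunk
    free = torch.nonzero(mem_state == 0).squeeze(1)[:chunk].to(torch.int32)
    if free.shape[0] < missing:                        # out of memory
        return None
    mem_state[free.long()] += 1                        # mark slots as referenced
    prefetch_buffer = torch.cat((prefetch_buffer, free))
    out, prefetch_buffer = prefetch_buffer[:need_size], prefetch_buffer[need_size:]
    return out

print(alloc(10).shape)   # torch.Size([10]); the next ~500 allocations reuse the chunk
```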
sglang/srt/server.py
CHANGED
@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg

     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
     if server_args.show_time_cost:
         enable_show_time_cost()
     if server_args.disable_disk_cache:
sglang/srt/server_args.py
CHANGED
@@ -29,7 +29,7 @@ class ServerArgs:
     max_prefill_tokens: Optional[int] = None
     max_running_requests: Optional[int] = None
     schedule_heuristic: str = "lpm"
-    schedule_conservativeness: float =
+    schedule_conservativeness: float = 0.8

     # Other runtime options
     tp_size: int = 1
@@ -53,6 +53,7 @@ class ServerArgs:
     disable_flashinfer: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
+    disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
@@ -67,13 +68,13 @@ class ServerArgs:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
             if self.tp_size >= 8:
-                self.mem_fraction_static = 0.
+                self.mem_fraction_static = 0.78
             elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.
+                self.mem_fraction_static = 0.80
             elif self.tp_size >= 2:
                 self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.
+                self.mem_fraction_static = 0.88
         if isinstance(self.additional_ports, int):
             self.additional_ports = [self.additional_ports]
         elif self.additional_ports is None:
@@ -294,6 +295,11 @@ class ServerArgs:
             action="store_true",
             help="Disable regex jump-forward",
         )
+        parser.add_argument(
+            "--disable-cuda-graph",
+            action="store_true",
+            help="Disable cuda graph.",
+        )
         parser.add_argument(
             "--disable-disk-cache",
             action="store_true",