sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/bench_latency.py +6 -3
- sglang/global_config.py +22 -16
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/srt/layers/radix_attention.py +14 -37
- sglang/srt/layers/token_attention.py +2 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/infer_batch.py +256 -42
- sglang/srt/managers/controller/manager_multi.py +6 -2
- sglang/srt/managers/controller/manager_single.py +125 -50
- sglang/srt/managers/controller/model_runner.py +69 -284
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +44 -44
- sglang/srt/memory_pool.py +52 -50
- sglang/srt/models/minicpm.py +1 -8
- sglang/srt/models/qwen2_moe.py +126 -107
- sglang/srt/server.py +11 -15
- sglang/srt/server_args.py +12 -4
- sglang/srt/utils.py +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/METADATA +9 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/RECORD +27 -26
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.19.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py

@@ -4,20 +4,24 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type
+from typing import Optional, Type
 
-import numpy as np
 import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -29,210 +33,6 @@ from sglang.srt.utils import (
 
 logger = logging.getLogger("srt.model_runner")
 
-# for server args in model endpoints
-global_server_args_dict = {}
-
-
-@dataclass
-class InputMetadata:
-    forward_mode: ForwardMode
-    batch_size: int
-    total_num_tokens: int
-    max_seq_len: int
-    req_pool_indices: torch.Tensor
-    start_loc: torch.Tensor
-    seq_lens: torch.Tensor
-    prefix_lens: torch.Tensor
-    positions: torch.Tensor
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
-
-    # for extend
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    max_extend_len: int = 0
-
-    out_cache_loc: torch.Tensor = None
-    out_cache_cont_start: torch.Tensor = None
-    out_cache_cont_end: torch.Tensor = None
-
-    other_kv_index: torch.Tensor = None
-    return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
-
-    # for flashinfer
-    qo_indptr: torch.Tensor = None
-    kv_indptr: torch.Tensor = None
-    kv_indices: torch.Tensor = None
-    kv_last_page_len: torch.Tensor = None
-    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            paged_kernel_lens = self.prefix_lens
-            self.no_prefix = torch.all(self.prefix_lens == 0)
-        else:
-            paged_kernel_lens = self.seq_lens
-
-        self.kv_indptr = torch.zeros(
-            (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
-        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        self.kv_last_page_len = torch.ones(
-            (self.batch_size,), dtype=torch.int32, device="cuda"
-        )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        self.kv_indices = torch.cat(
-            [
-                self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(self.batch_size)
-            ],
-            dim=0,
-        ).contiguous()
-
-        if (
-            self.forward_mode == ForwardMode.PREFILL
-            or self.forward_mode == ForwardMode.EXTEND
-        ):
-            # extend part
-            self.qo_indptr = torch.zeros(
-                (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-            )
-            self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-            self.flashinfer_prefill_wrapper_ragged.end_forward()
-            self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                self.qo_indptr,
-                self.qo_indptr.clone(),
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-            )
-
-            # cached part
-            self.flashinfer_prefill_wrapper_paged.end_forward()
-            self.flashinfer_prefill_wrapper_paged.begin_forward(
-                self.qo_indptr,
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-            )
-        else:
-            self.flashinfer_decode_wrapper.end_forward()
-            self.flashinfer_decode_wrapper.begin_forward(
-                self.kv_indptr,
-                self.kv_indices,
-                self.kv_last_page_len,
-                num_qo_heads,
-                num_kv_heads,
-                head_dim,
-                1,
-                pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype,
-            )
-
-    def init_extend_args(self):
-        self.extend_seq_lens = self.seq_lens - self.prefix_lens
-        self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-        self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        tp_size,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        out_cache_cont_start=None,
-        out_cache_cont_end=None,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        flashinfer_prefill_wrapper_ragged=None,
-        flashinfer_prefill_wrapper_paged=None,
-        flashinfer_decode_wrapper=None,
-    ):
-        batch_size = len(req_pool_indices)
-        start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-        start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-        total_num_tokens = int(torch.sum(seq_lens))
-        max_seq_len = int(torch.max(seq_lens))
-
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                req_pool_indices[0], seq_lens[0] - 1
-            ].item()
-        else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            other_kv_index = None
-
-        ret = cls(
-            forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            max_seq_len=max_seq_len,
-            req_pool_indices=req_pool_indices,
-            start_loc=start_loc,
-            seq_lens=seq_lens,
-            prefix_lens=prefix_lens,
-            positions=positions,
-            req_to_token_pool=model_runner.req_to_token_pool,
-            token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            out_cache_cont_start=out_cache_cont_start,
-            out_cache_cont_end=out_cache_cont_end,
-            other_kv_index=other_kv_index,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-        )
-
-        if forward_mode == ForwardMode.EXTEND:
-            ret.init_extend_args()
-
-        if not global_server_args_dict.get("disable_flashinfer", False):
-            ret.init_flashinfer_args(
-                model_runner.model_config.num_attention_heads // tp_size,
-                model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim,
-            )
-
-        return ret
-
 
 class ModelRunner:
     def __init__(
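Taken together, the two hunks above are a relocation rather than a deletion: `InputMetadata`, its `create()` factory, and `global_server_args_dict` move out of model_runner.py and into infer_batch.py (the +256/-42 entry in the file list), next to `Batch` and `ForwardMode`. A minimal sketch of the resulting import surface, using only the names the new import block shows:

    # Consumers now pull batch-level metadata from infer_batch:
    from sglang.srt.managers.controller.infer_batch import (
        Batch,                    # one scheduled batch of requests
        ForwardMode,              # PREFILL / EXTEND / DECODE
        InputMetadata,            # per-forward tensor bundle, built via InputMetadata.create(...)
        global_server_args_dict,  # process-wide server flags, shared by reference
    )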
@@ -245,6 +45,7 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
     ):
+        # Parse args
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.gpu_id = gpu_id
@@ -256,7 +57,6 @@ class ModelRunner:
         monkey_patch_vllm_dummy_weight_loader()
 
         # Init torch distributed
-        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
 
@@ -275,6 +75,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -287,11 +88,10 @@ class ModelRunner:
         )
 
         # Set some global args
-        global global_server_args_dict
-        global_server_args_dict = {
-            "disable_flashinfer": server_args.disable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+        global_server_args_dict[
+            "attention_reduce_in_fp32"
+        ] = server_args.attention_reduce_in_fp32
 
         # Load the model and create memory pool
         self.load_model()
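Note the shape of the replacement: 0.1.19 rebound a fresh local dict, while 0.1.21 assigns keys into the dict imported from infer_batch. Because every importer holds a reference to that one dict object, only the in-place update is visible across modules. A standalone sketch of the distinction (plain Python, not sglang code):

    config = {}     # stands in for infer_batch.global_server_args_dict
    alias = config  # simulates another module's imported reference

    config = {"disable_flashinfer": True}  # rebinding: alias still sees the old, empty dict
    assert alias == {}

    config = alias
    config["disable_flashinfer"] = True    # in-place update: visible through every reference
    assert alias["disable_flashinfer"] is True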
@@ -299,6 +99,9 @@ class ModelRunner:
         self.init_cublas()
         self.init_flash_infer()
 
+        # Capture cuda graphs
+        self.init_cuda_graphs()
+
     def load_model(self):
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -391,67 +194,64 @@ class ModelRunner:
         return c
 
     def init_flash_infer(self):
-        if not self.server_args.disable_flashinfer:
-            from flashinfer import (
-                BatchDecodeWithPagedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
-                BatchPrefillWithRaggedKVCacheWrapper,
-            )
-            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+        if self.server_args.disable_flashinfer:
+            self.flashinfer_prefill_wrapper_ragged = None
+            self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+            return
 
-            if not _grouped_size_compiled_for_decode_kernels(
-                self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size),
-            ):
-                use_tensor_cores = True
-            else:
-                use_tensor_cores = False
+        from flashinfer import (
+            BatchDecodeWithPagedKVCacheWrapper,
+            BatchPrefillWithPagedKVCacheWrapper,
+            BatchPrefillWithRaggedKVCacheWrapper,
+        )
+        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
-            workspace_buffers = torch.empty(
-                2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
-            )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                workspace_buffers[0], "NHD"
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffers[1], "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-            )
+        if not _grouped_size_compiled_for_decode_kernels(
+            self.model_config.num_attention_heads // self.tp_size,
+            self.model_config.get_num_kv_heads(self.tp_size),
+        ):
+            use_tensor_cores = True
         else:
-            self.flashinfer_prefill_wrapper_ragged = (
-                self.flashinfer_prefill_wrapper_paged
-            ) = None
-            self.flashinfer_decode_wrapper = None
+            use_tensor_cores = False
 
-    @torch.inference_mode()
-    def forward_prefill(self, batch: Batch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.PREFILL,
-            tp_size=self.tp_size,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+        self.flashinfer_workspace_buffers = torch.empty(
+            2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
         )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
+        self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0], "NHD"
+        )
+        self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[1], "NHD"
+        )
+        self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
         )
 
+    def init_cuda_graphs(self):
+        from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+            self.cuda_graph_runner = None
+            return
+
+        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
+        self.cuda_graph_runner.capture(batch_size_list)
+
     @torch.inference_mode()
-    def forward_decode(self, batch: Batch):
+    def forward_decode(self, batch: Batch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+            return self.cuda_graph_runner.replay(batch)
+
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.DECODE,
-            tp_size=self.tp_size,
+            forward_mode=ForwardMode.DECODE,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
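Two additions above are worth unpacking. First, `use_tensor_cores` is enabled exactly when flashinfer has no grouped-query decode kernel compiled for this QO/KV head ratio. Second, `init_cuda_graphs` plus the early return in `forward_decode` wire in the new `CudaGraphRunner` (cuda_graph_runner.py, +196 lines, not shown in this diff). Its `capture`/`can_run`/`replay` surface suggests the standard PyTorch stream-capture pattern: record one fixed-shape decode step per batch size, then replay it with fresh inputs copied into the captured buffers. A generic sketch of that pattern, assuming static shapes and a CUDA device (an illustration of the technique, not the CudaGraphRunner source):

    import torch

    class CapturedStep:
        """Capture a fixed-shape forward once; replay it without per-kernel launch overhead."""

        def __init__(self, fn, example_input):
            self.fn = fn
            self.static_input = example_input.clone()  # graph replays on fixed addresses
            self.fn(self.static_input)                 # warm-up so lazy init happens outside capture
            torch.cuda.synchronize()
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):         # record instead of executing eagerly
                self.static_output = self.fn(self.static_input)

        def replay(self, new_input):
            self.static_input.copy_(new_input)  # refresh inputs in place
            self.graph.replay()                 # relaunch the whole recorded step at once
            return self.static_output

Capturing `[1, 2, 4, 8, 16, ..., 120]` as `init_cuda_graphs` does then amounts to one such object per batch size, with `can_run` as a size check and `replay` dispatching to the matching graph. Decode is the natural target because its per-step shapes depend only on batch size; prefill/extend shapes vary with prompt length.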
@@ -459,32 +259,23 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
-    def forward_extend(self, batch: Batch):
+    def forward_extend(self, batch: Batch):
         input_metadata = InputMetadata.create(
             self,
-            forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
+            forward_mode=ForwardMode.EXTEND,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            out_cache_cont_start=batch.out_cache_cont_start,
-            out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -495,17 +286,13 @@ class ModelRunner:
         input_metadata = InputMetadata.create(
             self,
             forward_mode=ForwardMode.EXTEND,
-            tp_size=self.tp_size,
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             prefix_lens=batch.prefix_lens,
             position_ids_offsets=batch.position_ids_offsets,
             out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
-            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
         return self.model.forward(
             batch.input_ids,
@@ -523,8 +310,6 @@ class ModelRunner:
             return self.forward_decode(batch)
         elif forward_mode == ForwardMode.EXTEND:
             return self.forward_extend(batch)
-        elif forward_mode == ForwardMode.PREFILL:
-            return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
 
sglang/srt/managers/controller/radix_cache.py

@@ -82,12 +82,12 @@ class RadixCache:
 
         if self.disable:
             if del_in_memory_pool:
-                self.token_to_kv_pool.
+                self.token_to_kv_pool.free(indices)
             else:
                 return torch.tensor([], dtype=torch.int64), self.root_node
 
         # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.
+        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
 
         if del_in_memory_pool:
             self.req_to_token_pool.free(req_pool_idx)
@@ -125,7 +125,8 @@ class RadixCache:
             if x.lock_ref > 0:
                 continue
 
-            num_evicted += evict_callback(x.value)
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)
 
         if len(x.parent.children) == 0:
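The eviction change above decouples accounting from the callback: `evict_callback` (now typically `token_to_kv_pool.free`) only releases slots, and the loop counts `len(x.value)` itself instead of relying on the callback's return value. For orientation, a hedged sketch of the LRU eviction walk this hunk sits inside; `collect_leaves`, `delete_leaf`, and `last_access_time` are illustrative stand-ins for the private helpers in radix_cache.py:

    def evict(tree, num_tokens, evict_callback):
        # Reclaim at least num_tokens KV slots, oldest unlocked leaves first.
        leaves = sorted(tree.collect_leaves(), key=lambda n: n.last_access_time)
        num_evicted = 0
        for x in leaves:
            if num_evicted >= num_tokens:
                break
            if x.lock_ref > 0:  # pinned by an in-flight request
                continue
            evict_callback(x.value)      # release the KV indices held by this leaf
            num_evicted += len(x.value)  # count slots directly (the 0.1.21 change)
            tree.delete_leaf(x)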
sglang/srt/managers/controller/schedule_heuristic.py

@@ -13,6 +13,10 @@ class ScheduleHeuristic:
         max_total_num_tokens,
         tree_cache,
     ):
+        if tree_cache.disable and schedule_heuristic == "lpm":
+            # LPM is meaningless when the tree cache is disabled.
+            schedule_heuristic = "fcfs"
+
         self.schedule_heuristic = schedule_heuristic
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens