sglang 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,20 +4,24 @@ import importlib
  import importlib.resources
  import logging
  import pkgutil
- from dataclasses import dataclass
  from functools import lru_cache
- from typing import List, Optional, Type
+ from typing import Optional, Type

- import numpy as np
  import torch
  import torch.nn as nn
  from vllm.config import DeviceConfig, LoadConfig
  from vllm.config import ModelConfig as VllmModelConfig
- from vllm.distributed import init_distributed_environment, initialize_model_parallel
+ from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
  from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry

- from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+ from sglang.global_config import global_config
+ from sglang.srt.managers.controller.infer_batch import (
+     Batch,
+     ForwardMode,
+     InputMetadata,
+     global_server_args_dict,
+ )
  from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -29,210 +33,6 @@ from sglang.srt.utils import (

  logger = logging.getLogger("srt.model_runner")

- # for server args in model endpoints
- global_server_args_dict = {}
-
-
- @dataclass
- class InputMetadata:
-     forward_mode: ForwardMode
-     batch_size: int
-     total_num_tokens: int
-     max_seq_len: int
-     req_pool_indices: torch.Tensor
-     start_loc: torch.Tensor
-     seq_lens: torch.Tensor
-     prefix_lens: torch.Tensor
-     positions: torch.Tensor
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: TokenToKVPool
-
-     # for extend
-     extend_seq_lens: torch.Tensor = None
-     extend_start_loc: torch.Tensor = None
-     max_extend_len: int = 0
-
-     out_cache_loc: torch.Tensor = None
-     out_cache_cont_start: torch.Tensor = None
-     out_cache_cont_end: torch.Tensor = None
-
-     other_kv_index: torch.Tensor = None
-     return_logprob: bool = False
-     top_logprobs_nums: List[int] = None
-
-     # for flashinfer
-     qo_indptr: torch.Tensor = None
-     kv_indptr: torch.Tensor = None
-     kv_indices: torch.Tensor = None
-     kv_last_page_len: torch.Tensor = None
-     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             paged_kernel_lens = self.prefix_lens
-             self.no_prefix = torch.all(self.prefix_lens == 0)
-         else:
-             paged_kernel_lens = self.seq_lens
-
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-         self.kv_last_page_len = torch.ones(
-             (self.batch_size,), dtype=torch.int32, device="cuda"
-         )
-         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-         self.kv_indices = torch.cat(
-             [
-                 self.req_to_token_pool.req_to_token[
-                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                 ]
-                 for i in range(self.batch_size)
-             ],
-             dim=0,
-         ).contiguous()
-
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             # extend part
-             self.qo_indptr = torch.zeros(
-                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-             )
-             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-             self.flashinfer_prefill_wrapper_ragged.end_forward()
-             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                 self.qo_indptr,
-                 self.qo_indptr.clone(),
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-             )
-
-             # cached part
-             self.flashinfer_prefill_wrapper_paged.end_forward()
-             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                 self.qo_indptr,
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-                 1,
-             )
-         else:
-             self.flashinfer_decode_wrapper.end_forward()
-             self.flashinfer_decode_wrapper.begin_forward(
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-                 1,
-                 pos_encoding_mode="NONE",
-                 data_type=self.token_to_kv_pool.kv_data[0].dtype,
-             )
-
-     def init_extend_args(self):
-         self.extend_seq_lens = self.seq_lens - self.prefix_lens
-         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-         self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-     @classmethod
-     def create(
-         cls,
-         model_runner,
-         tp_size,
-         forward_mode,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         out_cache_cont_start=None,
-         out_cache_cont_end=None,
-         top_logprobs_nums=None,
-         return_logprob=False,
-         flashinfer_prefill_wrapper_ragged=None,
-         flashinfer_prefill_wrapper_paged=None,
-         flashinfer_decode_wrapper=None,
-     ):
-         batch_size = len(req_pool_indices)
-         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-         start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-         total_num_tokens = int(torch.sum(seq_lens))
-         max_seq_len = int(torch.max(seq_lens))
-
-         if forward_mode == ForwardMode.DECODE:
-             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-             other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                 req_pool_indices[0], seq_lens[0] - 1
-             ].item()
-         else:
-             seq_lens_cpu = seq_lens.cpu().numpy()
-             prefix_lens_cpu = prefix_lens.cpu().numpy()
-             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-             positions = torch.tensor(
-                 np.concatenate(
-                     [
-                         np.arange(
-                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                         )
-                         for i in range(batch_size)
-                     ],
-                     axis=0,
-                 ),
-                 device="cuda",
-             )
-             other_kv_index = None
-
-         ret = cls(
-             forward_mode=forward_mode,
-             batch_size=batch_size,
-             total_num_tokens=total_num_tokens,
-             max_seq_len=max_seq_len,
-             req_pool_indices=req_pool_indices,
-             start_loc=start_loc,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             positions=positions,
-             req_to_token_pool=model_runner.req_to_token_pool,
-             token_to_kv_pool=model_runner.token_to_kv_pool,
-             out_cache_loc=out_cache_loc,
-             out_cache_cont_start=out_cache_cont_start,
-             out_cache_cont_end=out_cache_cont_end,
-             other_kv_index=other_kv_index,
-             return_logprob=return_logprob,
-             top_logprobs_nums=top_logprobs_nums,
-             flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-         )
-
-         if forward_mode == ForwardMode.EXTEND:
-             ret.init_extend_args()
-
-         if not global_server_args_dict.get("disable_flashinfer", False):
-             ret.init_flashinfer_args(
-                 model_runner.model_config.num_attention_heads // tp_size,
-                 model_runner.model_config.get_num_kv_heads(tp_size),
-                 model_runner.model_config.head_dim,
-             )
-
-         return ret
-

  class ModelRunner:
      def __init__(
@@ -245,6 +45,7 @@ class ModelRunner:
          nccl_port: int,
          server_args: ServerArgs,
      ):
+         # Parse args
          self.model_config = model_config
          self.mem_fraction_static = mem_fraction_static
          self.gpu_id = gpu_id
@@ -256,7 +57,6 @@ class ModelRunner:
          monkey_patch_vllm_dummy_weight_loader()

          # Init torch distributed
-         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
          torch.cuda.set_device(self.gpu_id)
          logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")

@@ -275,6 +75,7 @@ class ModelRunner:
              distributed_init_method=nccl_init_method,
          )
          initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+         self.tp_group = get_tp_group()
          total_gpu_memory = get_available_gpu_memory(
              self.gpu_id, distributed=self.tp_size > 1
          )
@@ -287,11 +88,10 @@
          )

          # Set some global args
-         global global_server_args_dict
-         global_server_args_dict = {
-             "disable_flashinfer": server_args.disable_flashinfer,
-             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-         }
+         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+         global_server_args_dict[
+             "attention_reduce_in_fp32"
+         ] = server_args.attention_reduce_in_fp32

          # Load the model and create memory pool
          self.load_model()
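
Note: this hunk stops rebinding a module-level global_server_args_dict and instead writes keys into the dict that is now imported from infer_batch, so every module holding a from-import reference keeps seeing the same, updated object. A minimal sketch of the aliasing rule this relies on (the names below are illustrative, not sglang's):

    # Illustrative only: rebinding vs. mutating a shared dict.
    shared = {}      # stands in for global_server_args_dict defined in infer_batch
    alias = shared   # stands in for `from ... import global_server_args_dict`

    # Rebinding creates a new object; the imported alias goes stale.
    shared = {"disable_flashinfer": True}
    assert alias == {}

    # Mutating the one shared object (the 0.1.21 approach) is visible everywhere.
    shared = alias
    shared["disable_flashinfer"] = True
    assert alias["disable_flashinfer"] is True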
@@ -299,6 +99,9 @@
          self.init_cublas()
          self.init_flash_infer()

+         # Capture cuda graphs
+         self.init_cuda_graphs()
+
      def load_model(self):
          logger.info(
              f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -391,67 +194,64 @@
          return c

      def init_flash_infer(self):
-         if not global_server_args_dict.get("disable_flashinfer", False):
-             from flashinfer import (
-                 BatchDecodeWithPagedKVCacheWrapper,
-                 BatchPrefillWithPagedKVCacheWrapper,
-                 BatchPrefillWithRaggedKVCacheWrapper,
-             )
-             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+         if self.server_args.disable_flashinfer:
+             self.flashinfer_prefill_wrapper_ragged = None
+             self.flashinfer_prefill_wrapper_paged = None
+             self.flashinfer_decode_wrapper = None
+             return

-             if not _grouped_size_compiled_for_decode_kernels(
-                 self.model_config.num_attention_heads // self.tp_size,
-                 self.model_config.get_num_kv_heads(self.tp_size),
-             ):
-                 use_tensor_cores = True
-             else:
-                 use_tensor_cores = False
+         from flashinfer import (
+             BatchDecodeWithPagedKVCacheWrapper,
+             BatchPrefillWithPagedKVCacheWrapper,
+             BatchPrefillWithRaggedKVCacheWrapper,
+         )
+         from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-             workspace_buffers = torch.empty(
-                 2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-             )
-             self.flashinfer_prefill_wrapper_ragged = (
-                 BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
-             )
-             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                 workspace_buffers[1], "NHD"
-             )
-             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-             )
+         if not _grouped_size_compiled_for_decode_kernels(
+             self.model_config.num_attention_heads // self.tp_size,
+             self.model_config.get_num_kv_heads(self.tp_size),
+         ):
+             use_tensor_cores = True
          else:
-             self.flashinfer_prefill_wrapper_ragged = (
-                 self.flashinfer_prefill_wrapper_paged
-             ) = None
-             self.flashinfer_decode_wrapper = None
+             use_tensor_cores = False

-     @torch.inference_mode()
-     def forward_prefill(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.PREFILL,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+         self.flashinfer_workspace_buffers = torch.empty(
+             2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
          )
-         return self.model.forward(
-             batch.input_ids, input_metadata.positions, input_metadata
+         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[0], "NHD"
+         )
+         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[1], "NHD"
+         )
+         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[0],
+             "NHD",
+             use_tensor_cores=use_tensor_cores,
          )

+     def init_cuda_graphs(self):
+         from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+             self.cuda_graph_runner = None
+             return
+
+         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+         self.cuda_graph_runner = CudaGraphRunner(
+             self, max_batch_size_to_capture=max(batch_size_list)
+         )
+         self.cuda_graph_runner.capture(batch_size_list)
+
      @torch.inference_mode()
-     def forward_extend(self, batch: Batch):
+     def forward_decode(self, batch: Batch):
+         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+             return self.cuda_graph_runner.replay(batch)
+
          input_metadata = InputMetadata.create(
              self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
+             forward_mode=ForwardMode.DECODE,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
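
Note: init_flash_infer now sizes its workspace from global_config.flashinfer_workspace_size and keeps the buffers on the runner, and the new init_cuda_graphs/forward_decode path replays a pre-captured CUDA graph whenever CudaGraphRunner.can_run accepts the batch size (graphs are captured for batch sizes 1, 2, 4 and multiples of 8 up to 120). The sketch below is the generic torch.cuda.CUDAGraph capture/replay recipe that this kind of runner builds on, not sglang's CudaGraphRunner:

    import torch

    # Generic capture/replay pattern (illustrative). Captured kernels run with
    # fixed shapes and fixed buffer addresses, which is why a runner captures
    # one graph per batch size and copies decode inputs into static buffers.
    model = torch.nn.Linear(16, 16).cuda()
    static_in = torch.zeros(8, 16, device="cuda")

    # Warm up on a side stream so one-time init work is not recorded.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_in)          # record once

    def run(batch: torch.Tensor) -> torch.Tensor:
        static_in.copy_(batch)                 # refill the captured input buffer
        graph.replay()                         # relaunch the recorded kernels
        return static_out.clone()

    print(run(torch.randn(8, 16, device="cuda")).shape)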
@@ -459,32 +259,23 @@ class ModelRunner:
              out_cache_loc=batch.out_cache_loc,
              top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
          )

      @torch.inference_mode()
-     def forward_decode(self, batch: Batch):
+     def forward_extend(self, batch: Batch):
          input_metadata = InputMetadata.create(
              self,
-             forward_mode=ForwardMode.DECODE,
-             tp_size=self.tp_size,
+             forward_mode=ForwardMode.EXTEND,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
              position_ids_offsets=batch.position_ids_offsets,
              out_cache_loc=batch.out_cache_loc,
-             out_cache_cont_start=batch.out_cache_cont_start,
-             out_cache_cont_end=batch.out_cache_cont_end,
              top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
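
Note: the flashinfer wrapper arguments (and tp_size) disappear from these InputMetadata.create calls. Since the ModelRunner itself is passed as the first argument, the wrappers can presumably be looked up from its attributes inside infer_batch. A hypothetical, much-reduced sketch of that shape (not the real InputMetadata signature):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class TinyInputMetadata:
        forward_mode: str
        flashinfer_decode_wrapper: Optional[Any] = None

        @classmethod
        def create(cls, model_runner, forward_mode: str) -> "TinyInputMetadata":
            # With the runner in hand, per-call plumbing of the wrappers is
            # unnecessary; they can be read off the runner's attributes.
            return cls(
                forward_mode=forward_mode,
                flashinfer_decode_wrapper=getattr(
                    model_runner, "flashinfer_decode_wrapper", None
                ),
            )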
@@ -495,17 +286,13 @@ class ModelRunner:
          input_metadata = InputMetadata.create(
              self,
              forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
              position_ids_offsets=batch.position_ids_offsets,
              out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+             top_logprobs_nums=batch.top_logprobs_nums,
          )
          return self.model.forward(
              batch.input_ids,
@@ -523,8 +310,6 @@ class ModelRunner:
              return self.forward_decode(batch)
          elif forward_mode == ForwardMode.EXTEND:
              return self.forward_extend(batch)
-         elif forward_mode == ForwardMode.PREFILL:
-             return self.forward_prefill(batch)
          else:
              raise ValueError(f"Invaid forward mode: {forward_mode}")

@@ -82,12 +82,12 @@ class RadixCache:

          if self.disable:
              if del_in_memory_pool:
-                 self.token_to_kv_pool.dec_refs(indices)
+                 self.token_to_kv_pool.free(indices)
              else:
                  return torch.tensor([], dtype=torch.int64), self.root_node

          # Radix Cache takes one ref in memory pool
-         self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+         self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])

          if del_in_memory_pool:
              self.req_to_token_pool.free(req_pool_idx)
@@ -125,7 +125,8 @@ class RadixCache:
              if x.lock_ref > 0:
                  continue

-             num_evicted += evict_callback(x.value)
+             evict_callback(x.value)
+             num_evicted += len(x.value)
              self._delete_leaf(x)

              if len(x.parent.children) == 0:
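
Note: the two RadixCache hunks rename the KV-pool call from dec_refs to free and change the eviction accounting so the number of evicted tokens comes from len(x.value) rather than the callback's return value, which lets the callback (now effectively token_to_kv_pool.free) return nothing. A simplified model of that accounting, assuming leaves are (last_access_time, token_indices) pairs:

    import heapq

    # Simplified eviction loop (not the real RadixCache): the callback only
    # releases KV slots; the evicted-token count comes from the node itself.
    def evict(leaves, num_tokens, evict_callback):
        num_evicted = 0
        heapq.heapify(leaves)
        while num_evicted < num_tokens and leaves:
            _, value = heapq.heappop(leaves)   # least recently used leaf
            evict_callback(value)              # e.g. token_to_kv_pool.free(value)
            num_evicted += len(value)          # count tokens here, not via the callback
        return num_evicted

    freed = []
    n = evict([(2, [5, 6, 7]), (1, [1, 2])], num_tokens=4, evict_callback=freed.extend)
    print(n, freed)                            # 5 [1, 2, 5, 6, 7]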
@@ -13,6 +13,10 @@ class ScheduleHeuristic:
          max_total_num_tokens,
          tree_cache,
      ):
+         if tree_cache.disable and schedule_heuristic == "lpm":
+             # LMP is not meaningless when tree cache is disabled.
+             schedule_heuristic = "fcfs"
+
          self.schedule_heuristic = schedule_heuristic
          self.max_running_seqs = max_running_seqs
          self.max_prefill_num_tokens = max_prefill_num_tokens
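
Note: the final hunk makes ScheduleHeuristic fall back from "lpm" (longest prefix match) to "fcfs" when the radix tree cache is disabled; the in-code comment reads awkwardly, but the intent is that LPM is meaningless without a prefix cache, since every request then matches a zero-length prefix. A toy illustration under that assumed semantics (not sglang's scheduler):

    # With the cache disabled, every matched prefix length is 0, so sorting by
    # it carries no information and arrival order (FCFS) is the sane fallback.
    requests = [
        {"arrival": 0, "matched_prefix_len": 0},
        {"arrival": 1, "matched_prefix_len": 0},
        {"arrival": 2, "matched_prefix_len": 0},
    ]

    lpm_order = sorted(requests, key=lambda r: -r["matched_prefix_len"])  # all ties
    fcfs_order = sorted(requests, key=lambda r: r["arrival"])

    print([r["arrival"] for r in lpm_order])   # ordering conveys nothing
    print([r["arrival"] for r in fcfs_order])  # deterministic arrival order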