sglang 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,11 +4,9 @@ import importlib
  import importlib.resources
  import logging
  import pkgutil
- from dataclasses import dataclass
  from functools import lru_cache
- from typing import List, Optional, Type
+ from typing import Optional, Type

- import numpy as np
  import torch
  import torch.nn as nn
  from vllm.config import DeviceConfig, LoadConfig
@@ -17,7 +15,8 @@ from vllm.distributed import init_distributed_environment, initialize_model_para
  from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry

- from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+ from sglang.global_config import global_config
+ from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
  from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -29,210 +28,6 @@ from sglang.srt.utils import (

  logger = logging.getLogger("srt.model_runner")

- # for server args in model endpoints
- global_server_args_dict = {}
-
-
- @dataclass
- class InputMetadata:
-     forward_mode: ForwardMode
-     batch_size: int
-     total_num_tokens: int
-     max_seq_len: int
-     req_pool_indices: torch.Tensor
-     start_loc: torch.Tensor
-     seq_lens: torch.Tensor
-     prefix_lens: torch.Tensor
-     positions: torch.Tensor
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: TokenToKVPool
-
-     # for extend
-     extend_seq_lens: torch.Tensor = None
-     extend_start_loc: torch.Tensor = None
-     max_extend_len: int = 0
-
-     out_cache_loc: torch.Tensor = None
-     out_cache_cont_start: torch.Tensor = None
-     out_cache_cont_end: torch.Tensor = None
-
-     other_kv_index: torch.Tensor = None
-     return_logprob: bool = False
-     top_logprobs_nums: List[int] = None
-
-     # for flashinfer
-     qo_indptr: torch.Tensor = None
-     kv_indptr: torch.Tensor = None
-     kv_indices: torch.Tensor = None
-     kv_last_page_len: torch.Tensor = None
-     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
-     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
-     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
-
-     def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             paged_kernel_lens = self.prefix_lens
-             self.no_prefix = torch.all(self.prefix_lens == 0)
-         else:
-             paged_kernel_lens = self.seq_lens
-
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-         self.kv_last_page_len = torch.ones(
-             (self.batch_size,), dtype=torch.int32, device="cuda"
-         )
-         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-         paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-         self.kv_indices = torch.cat(
-             [
-                 self.req_to_token_pool.req_to_token[
-                     req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                 ]
-                 for i in range(self.batch_size)
-             ],
-             dim=0,
-         ).contiguous()
-
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             # extend part
-             self.qo_indptr = torch.zeros(
-                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-             )
-             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-
-             self.flashinfer_prefill_wrapper_ragged.end_forward()
-             self.flashinfer_prefill_wrapper_ragged.begin_forward(
-                 self.qo_indptr,
-                 self.qo_indptr.clone(),
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-             )
-
-             # cached part
-             self.flashinfer_prefill_wrapper_paged.end_forward()
-             self.flashinfer_prefill_wrapper_paged.begin_forward(
-                 self.qo_indptr,
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-                 1,
-             )
-         else:
-             self.flashinfer_decode_wrapper.end_forward()
-             self.flashinfer_decode_wrapper.begin_forward(
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_dim,
-                 1,
-                 pos_encoding_mode="NONE",
-                 data_type=self.token_to_kv_pool.kv_data[0].dtype,
-             )
-
-     def init_extend_args(self):
-         self.extend_seq_lens = self.seq_lens - self.prefix_lens
-         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-         self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-     @classmethod
-     def create(
-         cls,
-         model_runner,
-         tp_size,
-         forward_mode,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         out_cache_cont_start=None,
-         out_cache_cont_end=None,
-         top_logprobs_nums=None,
-         return_logprob=False,
-         flashinfer_prefill_wrapper_ragged=None,
-         flashinfer_prefill_wrapper_paged=None,
-         flashinfer_decode_wrapper=None,
-     ):
-         batch_size = len(req_pool_indices)
-         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-         start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-         total_num_tokens = int(torch.sum(seq_lens))
-         max_seq_len = int(torch.max(seq_lens))
-
-         if forward_mode == ForwardMode.DECODE:
-             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-             other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                 req_pool_indices[0], seq_lens[0] - 1
-             ].item()
-         else:
-             seq_lens_cpu = seq_lens.cpu().numpy()
-             prefix_lens_cpu = prefix_lens.cpu().numpy()
-             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-             positions = torch.tensor(
-                 np.concatenate(
-                     [
-                         np.arange(
-                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                         )
-                         for i in range(batch_size)
-                     ],
-                     axis=0,
-                 ),
-                 device="cuda",
-             )
-             other_kv_index = None
-
-         ret = cls(
-             forward_mode=forward_mode,
-             batch_size=batch_size,
-             total_num_tokens=total_num_tokens,
-             max_seq_len=max_seq_len,
-             req_pool_indices=req_pool_indices,
-             start_loc=start_loc,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             positions=positions,
-             req_to_token_pool=model_runner.req_to_token_pool,
-             token_to_kv_pool=model_runner.token_to_kv_pool,
-             out_cache_loc=out_cache_loc,
-             out_cache_cont_start=out_cache_cont_start,
-             out_cache_cont_end=out_cache_cont_end,
-             other_kv_index=other_kv_index,
-             return_logprob=return_logprob,
-             top_logprobs_nums=top_logprobs_nums,
-             flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=flashinfer_decode_wrapper,
-         )
-
-         if forward_mode == ForwardMode.EXTEND:
-             ret.init_extend_args()
-
-         if not global_server_args_dict.get("disable_flashinfer", False):
-             ret.init_flashinfer_args(
-                 model_runner.model_config.num_attention_heads // tp_size,
-                 model_runner.model_config.get_num_kv_heads(tp_size),
-                 model_runner.model_config.head_dim,
-             )
-
-         return ret
-

  class ModelRunner:
      def __init__(
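Note: the InputMetadata dataclass and global_server_args_dict are not dropped from the package. Per the updated imports at the top of this file, they now live in sglang.srt.managers.controller.infer_batch, and model_runner.py imports them from there; the call sites later in this diff also stop passing the tp_size and flashinfer_* keyword arguments. A minimal sketch of a post-move call, using only the keyword arguments visible in this diff (everything else, including the helper name, is an assumption rather than the package's exact code):

    from sglang.srt.managers.controller.infer_batch import ForwardMode, InputMetadata

    def build_decode_metadata(model_runner, batch):
        # Mirrors the forward_decode() call site shown later in this diff: the
        # flashinfer wrappers and tp_size are no longer passed in, since
        # InputMetadata can pull them from model_runner itself.
        return InputMetadata.create(
            model_runner,
            forward_mode=ForwardMode.DECODE,
            req_pool_indices=batch.req_pool_indices,
            seq_lens=batch.seq_lens,
            prefix_lens=batch.prefix_lens,
            position_ids_offsets=batch.position_ids_offsets,
            out_cache_loc=batch.out_cache_loc,
            top_logprobs_nums=batch.top_logprobs_nums,
            return_logprob=batch.return_logprob,
        )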
@@ -245,6 +40,7 @@ class ModelRunner:
          nccl_port: int,
          server_args: ServerArgs,
      ):
+         # Parse args
          self.model_config = model_config
          self.mem_fraction_static = mem_fraction_static
          self.gpu_id = gpu_id
@@ -256,7 +52,6 @@ class ModelRunner:
          monkey_patch_vllm_dummy_weight_loader()

          # Init torch distributed
-         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
          torch.cuda.set_device(self.gpu_id)
          logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")

@@ -287,11 +82,8 @@ class ModelRunner:
          )

          # Set some global args
-         global global_server_args_dict
-         global_server_args_dict = {
-             "disable_flashinfer": server_args.disable_flashinfer,
-             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-         }
+         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
+         global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32

          # Load the model and create memory pool
          self.load_model()
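Note on the global_server_args_dict change above: the 0.1.19 code rebound a module-level name inside model_runner.py at init time, which is only seen by code that looks the name up through the model_runner module; the 0.1.20 code mutates a single shared dict defined in infer_batch, which every importer observes regardless of when it imported. A small, self-contained illustration of the underlying Python semantics (not sglang code):

    # `shared` stands in for infer_batch.global_server_args_dict.
    shared = {}

    def configure_by_rebinding():
        global shared
        # Modules that already did `from this_module import shared`
        # still hold a reference to the old, empty dict.
        shared = {"disable_flashinfer": True}

    def configure_in_place(disable_flashinfer):
        # In-place mutation is visible through every reference to the object.
        shared["disable_flashinfer"] = disable_flashinfer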
@@ -299,6 +91,9 @@ class ModelRunner:
          self.init_cublas()
          self.init_flash_infer()

+         # Capture cuda graphs
+         self.init_cuda_graphs()
+
      def load_model(self):
          logger.info(
              f"[gpu_id={self.gpu_id}] Load weight begin. "
@@ -391,67 +186,60 @@ class ModelRunner:
          return c

      def init_flash_infer(self):
-         if not global_server_args_dict.get("disable_flashinfer", False):
-             from flashinfer import (
-                 BatchDecodeWithPagedKVCacheWrapper,
-                 BatchPrefillWithPagedKVCacheWrapper,
-                 BatchPrefillWithRaggedKVCacheWrapper,
-             )
-             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+         if self.server_args.disable_flashinfer:
+             self.flashinfer_prefill_wrapper_ragged = None
+             self.flashinfer_prefill_wrapper_paged = None
+             self.flashinfer_decode_wrapper = None
+             return

-             if not _grouped_size_compiled_for_decode_kernels(
-                 self.model_config.num_attention_heads // self.tp_size,
-                 self.model_config.get_num_kv_heads(self.tp_size),
-             ):
-                 use_tensor_cores = True
-             else:
-                 use_tensor_cores = False
+         from flashinfer import (
+             BatchDecodeWithPagedKVCacheWrapper,
+             BatchPrefillWithPagedKVCacheWrapper,
+             BatchPrefillWithRaggedKVCacheWrapper,
+         )
+         from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

-             workspace_buffers = torch.empty(
-                 2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
-             )
-             self.flashinfer_prefill_wrapper_ragged = (
-                 BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
-             )
-             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                 workspace_buffers[1], "NHD"
-             )
-             self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
-             )
+         if not _grouped_size_compiled_for_decode_kernels(
+             self.model_config.num_attention_heads // self.tp_size,
+             self.model_config.get_num_kv_heads(self.tp_size),
+         ):
+             use_tensor_cores = True
          else:
-             self.flashinfer_prefill_wrapper_ragged = (
-                 self.flashinfer_prefill_wrapper_paged
-             ) = None
-             self.flashinfer_decode_wrapper = None
+             use_tensor_cores = False

-     @torch.inference_mode()
-     def forward_prefill(self, batch: Batch):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.PREFILL,
-             tp_size=self.tp_size,
-             req_pool_indices=batch.req_pool_indices,
-             seq_lens=batch.seq_lens,
-             prefix_lens=batch.prefix_lens,
-             position_ids_offsets=batch.position_ids_offsets,
-             out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
-             return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+         self.flashinfer_workspace_buffers = torch.empty(
+             2, global_config.flashinfer_workspace_size, dtype=torch.uint8, device="cuda"
          )
-         return self.model.forward(
-             batch.input_ids, input_metadata.positions, input_metadata
+         self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[0], "NHD"
+         )
+         self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[1], "NHD"
          )
+         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+             self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+         )
+
+     def init_cuda_graphs(self):
+         from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+
+         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
+             self.cuda_graph_runner = None
+             return
+
+         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+         self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+         self.cuda_graph_runner.capture(batch_size_list)

      @torch.inference_mode()
-     def forward_extend(self, batch: Batch):
+     def forward_decode(self, batch: Batch):
+         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
+             return self.cuda_graph_runner.replay(batch)
+
          input_metadata = InputMetadata.create(
              self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
+             forward_mode=ForwardMode.DECODE,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
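The CudaGraphRunner used by init_cuda_graphs() and forward_decode() above lives in a new cuda_graph_runner module that this diff does not show. As a rough illustration of the capture-and-replay pattern such a runner is built on (a hedged sketch using torch.cuda.CUDAGraph, not the package's actual class; its can_run()/replay() batch-padding logic is omitted):

    import torch

    class TinyGraphRunner:
        """Capture a function once on static buffers, then replay it cheaply."""

        def __init__(self, fn, static_input):
            self.fn = fn
            self.static_input = static_input

            # Warm up on a side stream before capture, as CUDA graphs require.
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                self.fn(self.static_input)
            torch.cuda.current_stream().wait_stream(s)

            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_output = self.fn(self.static_input)

        def replay(self, new_input):
            # Copy fresh data into the captured buffer, then rerun the recorded
            # kernels without per-kernel launch overhead.
            self.static_input.copy_(new_input)
            self.graph.replay()
            return self.static_output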
@@ -459,32 +247,23 @@ class ModelRunner:
              out_cache_loc=batch.out_cache_loc,
              top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
          )

      @torch.inference_mode()
-     def forward_decode(self, batch: Batch):
+     def forward_extend(self, batch: Batch):
          input_metadata = InputMetadata.create(
              self,
-             forward_mode=ForwardMode.DECODE,
-             tp_size=self.tp_size,
+             forward_mode=ForwardMode.EXTEND,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
              position_ids_offsets=batch.position_ids_offsets,
              out_cache_loc=batch.out_cache_loc,
-             out_cache_cont_start=batch.out_cache_cont_start,
-             out_cache_cont_end=batch.out_cache_cont_end,
              top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
          )
          return self.model.forward(
              batch.input_ids, input_metadata.positions, input_metadata
@@ -495,17 +274,13 @@ class ModelRunner:
          input_metadata = InputMetadata.create(
              self,
              forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
              req_pool_indices=batch.req_pool_indices,
              seq_lens=batch.seq_lens,
              prefix_lens=batch.prefix_lens,
              position_ids_offsets=batch.position_ids_offsets,
              out_cache_loc=batch.out_cache_loc,
-             top_logprobs_nums=batch.top_logprobs_nums,
              return_logprob=batch.return_logprob,
-             flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
-             flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
-             flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
+             top_logprobs_nums=batch.top_logprobs_nums,
          )
          return self.model.forward(
              batch.input_ids,
@@ -523,8 +298,6 @@ class ModelRunner:
              return self.forward_decode(batch)
          elif forward_mode == ForwardMode.EXTEND:
              return self.forward_extend(batch)
-         elif forward_mode == ForwardMode.PREFILL:
-             return self.forward_prefill(batch)
          else:
              raise ValueError(f"Invaid forward mode: {forward_mode}")

@@ -98,7 +98,7 @@ class ModelTpServer:
          )
          self.max_total_num_tokens = self.model_runner.max_total_num_tokens
          self.max_prefill_tokens = (
-             4096
+             8192
              if server_args.max_prefill_tokens is None
              else server_args.max_prefill_tokens
          )
@@ -314,11 +314,9 @@ class ModelTpServer:
          self.forward_queue.append(req)

      def get_new_fill_batch(self) -> Optional[Batch]:
-         if (
-             self.running_batch is not None
-             and len(self.running_batch.reqs) > self.max_running_requests
-         ):
-             return None
+         running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
+         if running_bs >= self.max_running_requests:
+             return

          # Compute matched prefix length
          for req in self.forward_queue:
@@ -394,6 +392,10 @@ class ModelTpServer:
                      new_batch_input_tokens += req.extend_input_len
              else:
                  break
+
+             if running_bs + len(can_run_list) >= self.max_running_requests:
+                 break
+
          if len(can_run_list) == 0:
              return None

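Taken together, the two ModelTpServer hunks above cap concurrency by counting both the requests already running and the requests just admitted to the new prefill batch. A hedged sketch of that admission rule in isolation (the token-budget and prefix-cache checks of the real loop are deliberately omitted):

    def select_for_prefill(waiting, running_bs, max_running_requests):
        # Refuse to build a batch if the decode batch alone already hits the cap.
        if running_bs >= max_running_requests:
            return []
        can_run = []
        for req in waiting:
            can_run.append(req)
            # Stop once running plus newly admitted requests reach the cap.
            if running_bs + len(can_run) >= max_running_requests:
                break
        return can_run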
sglang/srt/memory_pool.py CHANGED
@@ -38,15 +38,24 @@ class ReqToTokenPool:

  class TokenToKVPool:
      def __init__(self, size, dtype, head_num, head_dim, layer_num):
-         self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
+         self.size = size
+         # mem_state is the reference counter.
+         # We also add one slot. This slot is used for writing dummy output from padded tokens.
+         self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
          self.total_ref_ct = 0

          # [size, key/value, head_num, head_dim] for each layer
          self.kv_data = [
-             torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
+             torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
              for _ in range(layer_num)
          ]

+         # Prefetch buffer
+         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
+         self.prefetch_chunk_size = 512
+
+         self.clear()
+
      def get_key_buffer(self, layer_id):
          return self.kv_data[layer_id][:, 0]

@@ -54,14 +63,29 @@ class TokenToKVPool:
          return self.kv_data[layer_id][:, 1]

      def alloc(self, need_size):
-         select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
-         if select_index.shape[0] < need_size:
+         buffer_len = len(self.prefetch_buffer)
+         if need_size <= buffer_len:
+             select_index = self.prefetch_buffer[:need_size]
+             self.prefetch_buffer = self.prefetch_buffer[need_size:]
+             return select_index
+
+         addition_size = need_size - buffer_len
+         alloc_size = max(addition_size, self.prefetch_chunk_size)
+         select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
+
+         if select_index.shape[0] < addition_size:
              return None

          self.add_refs(select_index)
-         return select_index.to(torch.int32)
+
+         self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
+         ret_index = self.prefetch_buffer[:need_size]
+         self.prefetch_buffer = self.prefetch_buffer[need_size:]
+
+         return ret_index

      def alloc_contiguous(self, need_size):
+         # NOTE: This function is deprecated.
          empty_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:need_size]
          if empty_index.shape[0] < need_size:
              return None
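The rewritten alloc() above amortizes the torch.nonzero() scan of mem_state: whenever the prefetch buffer runs dry it grabs at least prefetch_chunk_size (512) free slots in a single scan and serves subsequent small requests straight from the buffer. A standalone toy restatement of the same idea on a boolean free map (CPU tensors, not the class above):

    import torch

    class ChunkedAllocator:
        def __init__(self, size, chunk=512):
            self.free = torch.ones(size, dtype=torch.bool)
            self.buffer = torch.empty(0, dtype=torch.int32)
            self.chunk = chunk

        def alloc(self, need):
            # Fast path: an earlier bulk grab already covers this request.
            if need <= len(self.buffer):
                out, self.buffer = self.buffer[:need], self.buffer[need:]
                return out
            # Slow path: one scan that fetches at least a full chunk.
            missing = need - len(self.buffer)
            grab = max(missing, self.chunk)
            idx = torch.nonzero(self.free).squeeze(1)[:grab].to(torch.int32)
            if len(idx) < missing:
                return None  # same out-of-memory contract as the real alloc()
            self.free[idx.long()] = False
            self.buffer = torch.cat((self.buffer, idx))
            out, self.buffer = self.buffer[:need], self.buffer[need:]
            return out

    pool = ChunkedAllocator(2048)
    first = pool.alloc(3)    # one scan, 512 slots end up buffered
    second = pool.alloc(10)  # served from the buffer, no rescan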
@@ -84,7 +108,7 @@ class TokenToKVPool:
          return len(torch.nonzero(self.mem_state).squeeze(1))

      def available_size(self):
-         return torch.sum(self.mem_state == 0).item()
+         return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)

      def add_refs(self, token_index: torch.Tensor):
          self.total_ref_ct += len(token_index)
@@ -101,3 +125,6 @@ class TokenToKVPool:
      def clear(self):
          self.mem_state.fill_(0)
          self.total_ref_ct = 0
+
+         # We also add one slot. This slot is used for writing dummy output from padded tokens.
+         self.add_refs(torch.tensor([0], dtype=torch.int32))
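Note on the reserved slot: the pool now has size + 1 physical entries and clear() immediately takes a reference on index 0, so that slot is never handed out by alloc() and padded (dummy) tokens can all write their KV output there without touching live cache entries. A tiny CPU-only check of that invariant (plain tensors, not the class above):

    import torch

    size = 8
    mem_state = torch.zeros(size + 1, dtype=torch.int16)   # size + 1 slots
    mem_state[0] += 1                                       # what clear() does via add_refs([0])
    free_slots = torch.nonzero(mem_state == 0).squeeze(1)
    assert 0 not in free_slots.tolist()                     # slot 0 is never allocatable
    assert len(free_slots) == size                          # at most `size` usable slots remain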
sglang/srt/server.py CHANGED
@@ -146,6 +146,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg

      # Set global environments
      os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+     os.environ["NCCL_CUMEM_ENABLE"] = "0"
      if server_args.show_time_cost:
          enable_show_time_cost()
      if server_args.disable_disk_cache:
sglang/srt/server_args.py CHANGED
@@ -29,7 +29,7 @@ class ServerArgs:
      max_prefill_tokens: Optional[int] = None
      max_running_requests: Optional[int] = None
      schedule_heuristic: str = "lpm"
-     schedule_conservativeness: float = 1.0
+     schedule_conservativeness: float = 0.8

      # Other runtime options
      tp_size: int = 1
@@ -53,6 +53,7 @@ class ServerArgs:
      disable_flashinfer: bool = False
      disable_radix_cache: bool = False
      disable_regex_jump_forward: bool = False
+     disable_cuda_graph: bool = False
      disable_disk_cache: bool = False
      attention_reduce_in_fp32: bool = False
      enable_p2p_check: bool = False
@@ -67,13 +68,13 @@ class ServerArgs:
              self.tokenizer_path = self.model_path
          if self.mem_fraction_static is None:
              if self.tp_size >= 8:
-                 self.mem_fraction_static = 0.80
+                 self.mem_fraction_static = 0.78
              elif self.tp_size >= 4:
-                 self.mem_fraction_static = 0.82
+                 self.mem_fraction_static = 0.80
              elif self.tp_size >= 2:
                  self.mem_fraction_static = 0.85
              else:
-                 self.mem_fraction_static = 0.90
+                 self.mem_fraction_static = 0.88
          if isinstance(self.additional_ports, int):
              self.additional_ports = [self.additional_ports]
          elif self.additional_ports is None:
@@ -294,6 +295,11 @@ class ServerArgs:
              action="store_true",
              help="Disable regex jump-forward",
          )
+         parser.add_argument(
+             "--disable-cuda-graph",
+             action="store_true",
+             help="Disable cuda graph.",
+         )
          parser.add_argument(
              "--disable-disk-cache",
              action="store_true",
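With flashinfer enabled, CUDA graph capture is now on by default in 0.1.20, and the new flag above opts out. A hedged usage sketch: model_path is an existing ServerArgs field not shown in this diff, the model id is a placeholder, and the launch command in the comment is sglang's usual entry point.

    # CLI: python -m sglang.launch_server --model-path <model> --disable-cuda-graph
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
        disable_cuda_graph=True,
    )
    assert args.disable_cuda_graph
    assert args.mem_fraction_static == 0.88  # new tp_size == 1 default from this diff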
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.1.19
+ Version: 0.1.20
  Summary: A structured generation langauge for LLMs.
  License: Apache License
  Version 2.0, January 2004