sglang 0.1.22__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
@@ -1,11 +1,13 @@
 """Run the model with cuda graph."""
 
 import bisect
+from contextlib import contextmanager
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
@@ -14,10 +16,44 @@ from sglang.srt.managers.controller.infer_batch import (
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
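The `patch_model` helper added above swaps every vLLM `CustomOp` submodule onto its pure-PyTorch `forward_native` path before handing the forward function to `torch.compile`, and restores the CUDA path afterwards. A minimal, self-contained sketch of that dispatch-swap pattern follows; `ToyOp` and its two forward variants are hypothetical stand-ins, not sglang or vLLM code.

# Illustrative only: a toy stand-in for the CustomOp dispatch swap above.
import torch


class ToyOp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.forward_cuda  # default dispatch

    def forward_native(self, x):
        return torch.relu(x)  # pure-PyTorch path, traceable by torch.compile

    def forward_cuda(self, x):
        return torch.relu(x)  # stand-in for a custom-kernel path

    def forward(self, x):
        return self._forward_method(x)


model = torch.nn.Sequential(ToyOp(), torch.nn.Linear(4, 4))

# Same idea as _to_torch: point every ToyOp at its native path before compiling.
for sub in model.modules():
    if isinstance(sub, ToyOp):
        sub._forward_method = sub.forward_native

compiled = torch.compile(model, mode="max-autotune-no-cudagraphs")
print(compiled(torch.randn(2, 4)).shape)

The swap presumably matters because custom kernels are opaque to the compiler, while the native path can be traced; the original dispatch is put back in the `finally` block once capture is done.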
@@ -55,6 +91,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs
 
@@ -63,18 +101,23 @@
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                (
-                    graph,
-                    input_buffers,
-                    output_buffers,
-                    flashinfer_handler,
-                ) = self.capture_one_batch_size(bs)
-                self.graphs[bs] = graph
-                self.input_buffers[bs] = input_buffers
-                self.output_buffers[bs] = output_buffers
-                self.flashinfer_handlers[bs] = flashinfer_handler
-
-    def capture_one_batch_size(self, bs):
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
@@ -127,9 +170,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
-            return self.model_runner.model.forward(
-                input_ids, input_metadata.positions, input_metadata
-            )
+
+            return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
             run_once()
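The rewritten capture path still follows PyTorch's usual warm-up-then-capture recipe for CUDA graphs, only now it records whatever callable `patch_model` yielded (the raw or compiled forward). A generic sketch of that recipe, independent of sglang and requiring a CUDA device, looks like this; the model and shapes are invented.

# Generic CUDA-graph capture/replay pattern (needs a CUDA device).
import torch

assert torch.cuda.is_available()
model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture, like the run_once() loop above.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        static_out = model(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: copy fresh data into the captured input buffer, then replay the graph.
static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out.sum().item())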
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -431,7 +432,8 @@ class Batch:
 
     def retract_decode(self):
        sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
+
+        # TODO(lsyin): improve retraction policy for radix cache
         sorted_indices.sort(
             key=lambda i: (
                 len(self.reqs[i].output_ids),
@@ -443,7 +445,17 @@ class Batch:
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -468,7 +480,16 @@ class Batch:
 
         self.filter_batch(sorted_indices)
 
-        return retracted_reqs
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio
 
     def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
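A quick worked example of the ratio returned above; all numbers are invented, and `retract_decode_steps` is whatever `global_config` defines (20 is used here purely for illustration).

# Worked example of new_estimate_ratio with made-up per-request numbers.
retract_decode_steps = 20
decoded = [30, 5, 120]      # len(r.output_ids) for each remaining request
budgets = [128, 256, 512]   # r.sampling_params.max_new_tokens for each request

total_decoded_tokens = sum(decoded)   # 155
total_max_new_tokens = sum(budgets)   # 896
ratio = (total_decoded_tokens + retract_decode_steps * len(decoded)) / total_max_new_tokens
print(min(1.0, ratio))                # ~0.24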
@@ -668,18 +689,17 @@ class Batch:
 
         max_top_k_round, batch_size = 32, probs.shape[0]
         uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
             probs, uniform_samples, self.top_ks, self.top_ps
         )
 
-        # FIXME: this is a temporary fix for the illegal token ids
-        illegal_mask = torch.logical_or(
-            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
-        )
-        if torch.any(illegal_mask):
-            warnings.warn("Illegal sampled token ids")
+        if torch.any(~success):
+            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            batch_next_token_ids = torch.argmax(probs, dim=-1)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
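The new fallback keeps the sampler's result for every row where it reported success and substitutes greedy argmax elsewhere. A pure-PyTorch sketch of the same masking logic, with dummy tensors standing in for flashinfer's outputs:

# Sketch of the success-mask fallback; probs, sampled_ids, success are dummies.
import torch

probs = torch.softmax(torch.randn(4, 32000), dim=-1)
sampled_ids = torch.randint(0, 32000, (4,))
success = torch.tensor([True, False, True, True])

if torch.any(~success):
    probs = probs.masked_fill(torch.isnan(probs), 0.0)
    argmax_ids = torch.argmax(probs, dim=-1)
    sampled_ids = torch.where(success, sampled_ids, argmax_ids)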
@@ -727,6 +747,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False
 
     @classmethod
     def create(
@@ -742,7 +763,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -750,6 +774,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -804,6 +829,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -824,17 +850,19 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
 
     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -867,14 +895,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
 
         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
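The indptr arrays above are plain prefix sums over per-request lengths: when the ragged path is taken, the paged (cached) side only covers prefix tokens, while `qo_indptr` always covers the newly extended tokens. A toy, CPU-only sketch of that bookkeeping (the real code allocates on CUDA):

# Toy indptr bookkeeping with invented lengths.
import torch

seq_lens = torch.tensor([7, 5, 9], dtype=torch.int32)
prefix_lens = torch.tensor([3, 0, 4], dtype=torch.int32)
use_ragged = True

paged_kernel_lens = prefix_lens if use_ragged else seq_lens

batch_size = len(seq_lens)
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

print(kv_indptr.tolist())  # [0, 3, 3, 7]
print(qo_indptr.tolist())  # [0, 4, 9, 14]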
@@ -22,7 +22,6 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -241,16 +240,20 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return
 
-        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self, max_batch_size_to_capture=max(batch_size_list)
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
@@ -14,7 +14,7 @@ class ScheduleHeuristic:
         tree_cache,
     ):
         if tree_cache.disable and schedule_heuristic == "lpm":
-            # LMP is not meaningless when tree cache is disabled.
+            # LMP is meaningless when the tree cache is disabled.
             schedule_heuristic = "fcfs"
 
         self.schedule_heuristic = schedule_heuristic
@@ -28,11 +28,16 @@ class ScheduleHeuristic:
             # longest prefix match
             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
             return forward_queue
+        elif self.schedule_heuristic == "fcfs":
+            # first come first serve
+            return forward_queue
+        elif self.schedule_heuristic == "lof":
+            # longest output first
+            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return forward_queue
         elif self.schedule_heuristic == "random":
             random.shuffle(forward_queue)
             return forward_queue
-        elif self.schedule_heuristic == "fcfs":
-            return forward_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
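For reference, the new "lof" branch simply orders the queue by each request's decode budget, largest first. A throwaway illustration with stand-in request objects (names and classes here are invented for the example):

# Tiny illustration of "longest output first" ordering.
from dataclasses import dataclass


@dataclass
class _Params:
    max_new_tokens: int


@dataclass
class _Req:
    name: str
    sampling_params: _Params


queue = [_Req("a", _Params(32)), _Req("b", _Params(512)), _Req("c", _Params(128))]
queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
print([r.name for r in queue])  # ['b', 'c', 'a']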
@@ -103,6 +103,9 @@ class ModelTpServer:
             if server_args.max_running_requests is None
             else server_args.max_running_requests
         )
+        self.max_running_requests = min(
+            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -113,13 +116,9 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
         )
-        if self.tp_rank == 0:
-            logger.info(
-                f"[gpu_id={self.gpu_id}] "
-                f"server_args: {server_args.print_mode_args()}"
-            )
 
         # Init cache
         self.tree_cache = RadixCache(
@@ -161,15 +160,12 @@ class ModelTpServer:
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-        self.new_token_ratio = min(
-            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
-            1.0,
-        )
         self.min_new_token_ratio = min(
             global_config.base_min_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
+        self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
@@ -231,6 +227,7 @@ class ModelTpServer:
                 break
             else:
                 self.check_memory()
+                self.new_token_ratio = global_config.init_new_token_ratio
 
     def print_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -539,9 +536,10 @@ class ModelTpServer:
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
-            retracted_reqs = batch.retract_decode()
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
             logger.info(
                 "decode out of memory happened, "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
sglang/srt/memory_pool.py CHANGED
@@ -11,6 +11,7 @@ class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
     def __init__(self, size: int, max_context_len: int):
+        self.size = size
         self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
@@ -57,9 +58,13 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, key/value, head_num, head_dim] for each layer
-        self.kv_data = [
-            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
 
@@ -71,10 +76,13 @@ class TokenToKVPool:
         self.clear()
 
     def get_key_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 0]
+        return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 1]
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
 
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
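Splitting `kv_data` into `k_buffer`/`v_buffer` trades the packed `[size, key/value, head_num, head_dim]` layout for two separate per-layer tensors, so `get_key_buffer`/`get_value_buffer` now return contiguous storage instead of strided views. A toy comparison with made-up sizes:

# Old packed layout vs. new split layout (toy sizes, CPU tensors).
import torch

size, head_num, head_dim, layer_num = 8, 2, 4, 3

# old: one [size + 1, 2, head_num, head_dim] tensor per layer
kv_data = [torch.empty(size + 1, 2, head_num, head_dim) for _ in range(layer_num)]
k_old, v_old = kv_data[0][:, 0], kv_data[0][:, 1]  # strided views

# new: separate K and V buffers per layer
k_buffer = [torch.empty(size + 1, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.empty(size + 1, head_num, head_dim) for _ in range(layer_num)]

print(k_old.is_contiguous(), k_buffer[0].is_contiguous())  # False True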