sglang 0.1.22__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (39)
  1. sglang/__init__.py +2 -2
  2. sglang/bench_serving.py +243 -25
  3. sglang/global_config.py +3 -2
  4. sglang/lang/interpreter.py +1 -0
  5. sglang/srt/hf_transformers_utils.py +13 -1
  6. sglang/srt/layers/logits_processor.py +4 -5
  7. sglang/srt/layers/radix_attention.py +38 -49
  8. sglang/srt/managers/controller/cuda_graph_runner.py +58 -16
  9. sglang/srt/managers/controller/infer_batch.py +51 -22
  10. sglang/srt/managers/controller/model_runner.py +58 -4
  11. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  12. sglang/srt/managers/controller/tp_worker.py +9 -11
  13. sglang/srt/memory_pool.py +13 -5
  14. sglang/srt/models/deepseek.py +430 -0
  15. sglang/srt/models/gpt_bigcode.py +282 -0
  16. sglang/srt/models/llama2.py +19 -10
  17. sglang/srt/server.py +26 -1
  18. sglang/srt/server_args.py +12 -6
  19. sglang/srt/utils.py +93 -1
  20. sglang/version.py +1 -0
  21. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/METADATA +10 -6
  22. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/RECORD +25 -36
  23. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/WHEEL +1 -1
  24. sglang/backend/__init__.py +0 -0
  25. sglang/backend/anthropic.py +0 -77
  26. sglang/backend/base_backend.py +0 -80
  27. sglang/backend/litellm.py +0 -90
  28. sglang/backend/openai.py +0 -438
  29. sglang/backend/runtime_endpoint.py +0 -283
  30. sglang/backend/vertexai.py +0 -149
  31. sglang/bench.py +0 -627
  32. sglang/srt/managers/controller/dp_worker.py +0 -113
  33. sglang/srt/openai_api/api_adapter.py +0 -432
  34. sglang/srt/openai_api/openai_api_adapter.py +0 -431
  35. sglang/srt/openai_api/openai_protocol.py +0 -207
  36. sglang/srt/openai_api_adapter.py +0 -411
  37. sglang/srt/openai_protocol.py +0 -207
  38. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/LICENSE +0 -0
  39. {sglang-0.1.22.dist-info → sglang-0.1.25.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py CHANGED
@@ -85,32 +85,47 @@ class RadixAttention(nn.Module):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
-            q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-            causal=True,
-            sm_scale=self.scaling,
-            logits_soft_cap=self.logit_cap,
-        )
+        if not input_metadata.use_ragged:
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.extend_no_prefix:
-            o = o1
-        else:
-            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+            o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
-                causal=False,
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                causal=True,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
             )
+        else:
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
+            )
 
-            o, _ = merge_state(o1, s1, o2, s2)
+            if input_metadata.extend_no_prefix:
+                o = o1
+            else:
+                o2, s2 = (
+                    input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                        q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                        input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                        causal=False,
+                        sm_scale=self.scaling,
+                        logits_soft_cap=self.logit_cap,
+                    )
+                )
 
-        self.store_kv_cache(k, v, input_metadata)
+                o, _ = merge_state(o1, s1, o2, s2)
+
+            self.store_kv_cache(k, v, input_metadata)
 
-        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
-            torch.cuda.synchronize()
+            if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+                torch.cuda.synchronize()
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
@@ -119,7 +134,7 @@ class RadixAttention(nn.Module):
 
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
@@ -136,33 +151,7 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
-        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
-
-
-try:
-
-    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
-
-    @_store_kv_cache.register_fake
-    def _(k, v, kv_cache, cache_loc):
-        pass
-
-except:
-
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
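Note: the ragged branch above combines two partial attention results (ragged over the new tokens, paged over the cached prefix) with flashinfer's merge_state. As a rough, self-contained sketch of the math only (shapes simplified to [tokens, heads, head_dim] outputs and [tokens, heads] log-sum-exp values; flashinfer's actual kernel and LSE convention may differ):

    import torch

    def merge_state_reference(o1, s1, o2, s2):
        # o1, o2: [num_tokens, num_heads, head_dim] partial attention outputs.
        # s1, s2: [num_tokens, num_heads] log-sum-exp of the attention logits.
        # Weight each partial output by its softmax normalizer so the merged
        # result equals attention over the union of the two key/value sets.
        m = torch.maximum(s1, s2)
        w1 = torch.exp(s1 - m).unsqueeze(-1)
        w2 = torch.exp(s2 - m).unsqueeze(-1)
        o = (o1 * w1 + o2 * w2) / (w1 + w2)
        s = m + torch.log(torch.exp(s1 - m) + torch.exp(s2 - m))
        return o, s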
sglang/srt/managers/controller/cuda_graph_runner.py CHANGED
@@ -1,11 +1,13 @@
 """Run the model with cuda graph."""
 
 import bisect
+from contextlib import contextmanager
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
@@ -14,10 +16,44 @@ from sglang.srt.managers.controller.infer_batch import (
     InputMetadata,
     init_flashinfer_args,
 )
+from sglang.srt.utils import monkey_patch_vllm_all_gather
+
+
+def _to_torch(model: torch.nn.Module, reverse: bool = False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+@contextmanager
+def patch_model(
+    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+):
+    backup_ca_comm = None
+
+    try:
+        if use_compile:
+            _to_torch(model)
+            monkey_patch_vllm_all_gather()
+            backup_ca_comm = tp_group.ca_comm
+            tp_group.ca_comm = None
+            yield torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+        else:
+            yield model.forward
+    finally:
+        if use_compile:
+            _to_torch(model, reverse=True)
+            monkey_patch_vllm_all_gather(reverse=True)
+            tp_group.ca_comm = backup_ca_comm
 
 
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
@@ -55,6 +91,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs
 
@@ -63,18 +101,23 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-                (
-                    graph,
-                    input_buffers,
-                    output_buffers,
-                    flashinfer_handler,
-                ) = self.capture_one_batch_size(bs)
-                self.graphs[bs] = graph
-                self.input_buffers[bs] = input_buffers
-                self.output_buffers[bs] = output_buffers
-                self.flashinfer_handlers[bs] = flashinfer_handler
-
-    def capture_one_batch_size(self, bs):
+                with patch_model(
+                    self.model_runner.model,
+                    bs in self.compile_bs,
+                    self.model_runner.tp_group,
+                ) as forward:
+                    (
+                        graph,
+                        input_buffers,
+                        output_buffers,
+                        flashinfer_handler,
+                    ) = self.capture_one_batch_size(bs, forward)
+                    self.graphs[bs] = graph
+                    self.input_buffers[bs] = input_buffers
+                    self.output_buffers[bs] = output_buffers
+                    self.flashinfer_handlers[bs] = flashinfer_handler
+
+    def capture_one_batch_size(self, bs, forward):
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
 
@@ -127,9 +170,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
-            return self.model_runner.model.forward(
-                input_ids, input_metadata.positions, input_metadata
-            )
+
+            return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
             run_once()
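Note: the capture loop above follows the standard CUDA graph pattern: warm up, record the forward pass into static buffers, then replay the recorded graph after copying fresh inputs into those buffers. A minimal, hedged sketch of that pattern on a toy module (requires a CUDA GPU; not sglang's runner):

    import torch

    model = torch.nn.Linear(16, 16).cuda()
    static_x = torch.zeros(8, 16, device="cuda")

    # Warm up on a side stream so lazy initialization is not recorded.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            model(static_x)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = model(static_x)  # the output buffer is fixed by the capture

    # Replay: copy new inputs into the captured buffer, then replay the graph.
    static_x.copy_(torch.randn(8, 16, device="cuda"))
    graph.replay()
    print(static_y.shape)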
sglang/srt/managers/controller/infer_batch.py CHANGED
@@ -9,6 +9,7 @@ import numpy as np
 import torch
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
 
+from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
@@ -431,7 +432,8 @@ class Batch:
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
-        # TODO(lsyin): improve the priority of retraction
+
+        # TODO(lsyin): improve retraction policy for radix cache
         sorted_indices.sort(
             key=lambda i: (
                 len(self.reqs[i].output_ids),
@@ -443,7 +445,17 @@ class Batch:
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        while self.token_to_kv_pool.available_size() < len(self.reqs):
+        while (
+            self.token_to_kv_pool.available_size()
+            < len(sorted_indices) * global_config.retract_decode_steps
+        ):
+            if len(sorted_indices) == 1:
+                # Corner case: only one request left
+                assert (
+                    self.token_to_kv_pool.available_size() > 0
+                ), "No space left for only one request"
+                break
+
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -468,7 +480,16 @@ class Batch:
 
         self.filter_batch(sorted_indices)
 
-        return retracted_reqs
+        # Reqs in batch are filtered
+        total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
+        total_max_new_tokens = sum(r.sampling_params.max_new_tokens for r in self.reqs)
+
+        new_estimate_ratio = (
+            total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
+        ) / total_max_new_tokens
+        new_estimate_ratio = min(1.0, new_estimate_ratio)
+
+        return retracted_reqs, new_estimate_ratio
 
     def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
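Note: the new return value estimates how close the remaining requests already are to their generation budget. A plain-Python illustration of the arithmetic with made-up numbers (retract_decode_steps and the request fields are stand-ins for the real objects):

    # Illustrative only: stand-in values for the real Batch / global_config objects.
    retract_decode_steps = 20
    decoded = [30, 5, 12]            # len(r.output_ids) per remaining request
    max_new_tokens = [256, 64, 128]  # r.sampling_params.max_new_tokens per request

    total_decoded_tokens = sum(decoded)
    total_max_new_tokens = sum(max_new_tokens)
    new_estimate_ratio = (
        total_decoded_tokens + retract_decode_steps * len(decoded)
    ) / total_max_new_tokens
    new_estimate_ratio = min(1.0, new_estimate_ratio)
    print(new_estimate_ratio)  # (47 + 60) / 448 ≈ 0.239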
@@ -668,18 +689,17 @@ class Batch:
 
         max_top_k_round, batch_size = 32, probs.shape[0]
         uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
+        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
             probs, uniform_samples, self.top_ks, self.top_ps
         )
 
-        # FIXME: this is a temporary fix for the illegal token ids
-        illegal_mask = torch.logical_or(
-            batch_next_token_ids < 0, batch_next_token_ids >= probs.shape[-1]
-        )
-        if torch.any(illegal_mask):
-            warnings.warn("Illegal sampled token ids")
+        if torch.any(~success):
+            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            batch_next_token_ids = torch.argmax(probs, dim=-1)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
 
         if has_regex:
             batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
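Note: the change uses the per-row success mask returned by the sampler and falls back to greedy (argmax) only for the rows that failed. A small hedged sketch of that masking pattern with plain torch tensors (random data and a hand-written mask, not the real flashinfer sampler):

    import torch

    # Toy stand-ins: a [batch, vocab] probability matrix and a per-row success mask.
    probs = torch.rand(4, 10)
    probs = probs / probs.sum(dim=-1, keepdim=True)
    sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
    success = torch.tensor([True, False, True, False])

    if torch.any(~success):
        # Greedy fallback only for the rows where sampling failed.
        probs = probs.masked_fill(torch.isnan(probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        sampled_ids = torch.where(success, sampled_ids, argmax_ids)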
@@ -727,6 +747,7 @@ class InputMetadata:
     flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
+    use_ragged: bool = False
 
     @classmethod
     def create(
@@ -742,7 +763,10 @@ class InputMetadata:
         return_logprob=False,
         skip_flashinfer_init=False,
     ):
+        use_ragged = False
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
+            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
+                use_ragged = True
             init_flashinfer_args(
                 forward_mode,
                 model_runner,
@@ -750,6 +774,7 @@ class InputMetadata:
                 seq_lens,
                 prefix_lens,
                 model_runner.flashinfer_decode_wrapper,
+                use_ragged,
             )
 
         batch_size = len(req_pool_indices)
@@ -804,6 +829,7 @@ class InputMetadata:
             flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
             flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
             flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
+            use_ragged=use_ragged,
         )
 
         if model_runner.server_args.disable_flashinfer:
@@ -824,17 +850,19 @@ def init_flashinfer_args(
     seq_lens,
     prefix_lens,
     flashinfer_decode_wrapper,
+    use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
+    total_num_tokens = int(torch.sum(seq_lens))
 
-    if forward_mode == ForwardMode.DECODE:
-        paged_kernel_lens = seq_lens
-    else:
+    if use_ragged:
         paged_kernel_lens = prefix_lens
+    else:
+        paged_kernel_lens = seq_lens
 
     kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -867,14 +895,15 @@ def init_flashinfer_args(
         qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-        model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-        model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-            qo_indptr,
-            qo_indptr,
-            num_qo_heads,
-            num_kv_heads,
-            head_dim,
-        )
+        if use_ragged:
+            model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+            model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+            )
 
         # cached part
         model_runner.flashinfer_prefill_wrapper_paged.end_forward()
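Note: the kv_indptr and qo_indptr arrays built in these hunks are standard CSR-style offset arrays: entry i marks where request i's indices start in the flattened list. A small, hedged example of the construction (CPU tensors, toy lengths):

    import torch

    # Example: three requests with paged KV lengths 3, 5, and 2.
    paged_kernel_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
    batch_size = paged_kernel_lens.numel()

    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    print(kv_indptr.tolist())  # [0, 3, 8, 10]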
sglang/srt/managers/controller/model_runner.py CHANGED
@@ -15,6 +15,7 @@ from flashinfer import (
     BatchPrefillWithRaggedKVCacheWrapper,
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+from torch.nn.parameter import Parameter
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -22,7 +23,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
-from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -39,6 +40,18 @@ from sglang.srt.utils import (
 logger = logging.getLogger("srt.model_runner")
 
 
+def is_llama3_405b_fp8(model_config):
+    if (
+        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
+        and model_config.hf_config.hidden_size == 16384
+        and model_config.hf_config.intermediate_size == 53248
+        and model_config.hf_config.num_hidden_layers == 126
+        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
+    ):
+        return True
+    return False
+
+
 class ModelRunner:
     def __init__(
         self,
@@ -119,6 +132,9 @@ class ModelRunner:
             seed=42,
             skip_tokenizer_init=True,
         )
+        if is_llama3_405b_fp8(self.model_config):
+            self.model_config.hf_config.num_key_value_heads = 8
+            vllm_model_config.hf_config.num_key_value_heads = 8
         self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -241,16 +257,20 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return
 
-        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self, max_batch_size_to_capture=max(batch_size_list)
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
         except RuntimeError as e:
             raise Exception(
-                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
                 f"1. disable cuda graph by --disable-cuda-graph\n"
                 f"2. set --mem-fraction-static to a smaller value\n"
                 f"Open an issue on GitHub with reproducible scripts if you need help.\n"
@@ -367,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     return model_arch_name_to_cls[model_arch]
 
 
+def get_original_weight(loaded_weight, head_dim):
+    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
+    dim = loaded_weight.shape[1]
+    for i in range(n_kv_head):
+        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
+            2 * i * head_dim : (2 * i + 1) * head_dim, :
+        ]
+    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
+    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
+    return original_kv_weight
+
+
+def get_weight_loader_srt(weight_loader):
+    def weight_loader_srt(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if (
+            loaded_shard_id in ["k", "v"]
+            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
+        ):
+            loaded_weight = get_original_weight(loaded_weight, self.head_size)
+
+        weight_loader(self, param, loaded_weight, loaded_shard_id)
+
+    return weight_loader_srt
+
+
 # Monkey patch model loader
 setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+original_weight_loader = QKVParallelLinear.weight_loader
+setattr(
+    QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
+)
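Note: the weight-loader change above follows a generic wrap-and-setattr pattern: keep a reference to the original method, wrap it, and install the wrapper on the class. A stripped-down, hedged sketch of the same pattern on a toy class (names are illustrative, not vllm's API):

    class Loader:
        def load(self, weight):
            return weight

    def make_patched_load(original_load):
        def patched_load(self, weight):
            # Pre-process the argument, then delegate to the original method.
            weight = [w * 2 for w in weight]
            return original_load(self, weight)
        return patched_load

    original_load = Loader.load
    setattr(Loader, "load", make_patched_load(original_load))

    print(Loader().load([1, 2, 3]))  # [2, 4, 6]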
sglang/srt/managers/controller/schedule_heuristic.py CHANGED
@@ -14,7 +14,7 @@ class ScheduleHeuristic:
         tree_cache,
     ):
         if tree_cache.disable and schedule_heuristic == "lpm":
-            # LMP is not meaningless when tree cache is disabled.
+            # LMP is meaningless when the tree cache is disabled.
             schedule_heuristic = "fcfs"
 
         self.schedule_heuristic = schedule_heuristic
@@ -28,11 +28,16 @@ class ScheduleHeuristic:
             # longest prefix match
             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
             return forward_queue
+        elif self.schedule_heuristic == "fcfs":
+            # first come first serve
+            return forward_queue
+        elif self.schedule_heuristic == "lof":
+            # longest output first
+            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return forward_queue
         elif self.schedule_heuristic == "random":
             random.shuffle(forward_queue)
             return forward_queue
-        elif self.schedule_heuristic == "fcfs":
-            return forward_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
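Note: to make the new "lof" (longest output first) policy concrete, here is a hedged toy example of the sort key on stand-in request objects (the real requests carry many more fields):

    from dataclasses import dataclass

    @dataclass
    class FakeReq:
        name: str
        max_new_tokens: int

    queue = [FakeReq("a", 32), FakeReq("b", 512), FakeReq("c", 128)]

    # "lof": longest output first, i.e. descending max_new_tokens.
    queue.sort(key=lambda r: -r.max_new_tokens)
    print([r.name for r in queue])  # ['b', 'c', 'a']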
sglang/srt/managers/controller/tp_worker.py CHANGED
@@ -103,6 +103,9 @@ class ModelTpServer:
             if server_args.max_running_requests is None
             else server_args.max_running_requests
         )
+        self.max_running_requests = min(
+            self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+        )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
@@ -113,13 +116,9 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
+            f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}"
         )
-        if self.tp_rank == 0:
-            logger.info(
-                f"[gpu_id={self.gpu_id}] "
-                f"server_args: {server_args.print_mode_args()}"
-            )
 
         # Init cache
         self.tree_cache = RadixCache(
@@ -161,15 +160,12 @@ class ModelTpServer:
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-        self.new_token_ratio = min(
-            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
-            1.0,
-        )
         self.min_new_token_ratio = min(
             global_config.base_min_new_token_ratio
             * server_args.schedule_conservativeness,
             1.0,
         )
+        self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
@@ -231,6 +227,7 @@ class ModelTpServer:
                 break
         else:
             self.check_memory()
+            self.new_token_ratio = global_config.init_new_token_ratio
 
     def print_stats(self):
         num_used = self.max_total_num_tokens - (
@@ -539,9 +536,10 @@ class ModelTpServer:
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
-            retracted_reqs = batch.retract_decode()
+            retracted_reqs, new_token_ratio = batch.retract_decode()
+            self.new_token_ratio = new_token_ratio
+
             logger.info(
                 "decode out of memory happened, "
                 f"#retracted_reqs: {len(retracted_reqs)}, "
sglang/srt/memory_pool.py CHANGED
@@ -11,6 +11,7 @@ class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
     def __init__(self, size: int, max_context_len: int):
+        self.size = size
         self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device="cuda"
@@ -57,9 +58,13 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, key/value, head_num, head_dim] for each layer
-        self.kv_data = [
-            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
 
@@ -71,10 +76,13 @@ class TokenToKVPool:
         self.clear()
 
     def get_key_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 0]
+        return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 1]
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
 
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
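Note: the refactor replaces one [size, 2, head_num, head_dim] tensor per layer with separate key and value buffers, so reads and writes index a single buffer instead of slicing a key/value axis. A hedged toy comparison of the two layouts (small shapes, CPU tensors):

    import torch

    size, head_num, head_dim = 8, 2, 4
    loc = torch.tensor([1, 3])
    k = torch.randn(2, head_num, head_dim)
    v = torch.randn(2, head_num, head_dim)

    # Old layout: one buffer per layer with a key/value axis.
    kv_data = torch.empty(size, 2, head_num, head_dim)
    kv_data[loc, 0] = k
    kv_data[loc, 1] = v

    # New layout: separate key and value buffers per layer.
    k_buffer = torch.empty(size, head_num, head_dim)
    v_buffer = torch.empty(size, head_num, head_dim)
    k_buffer[loc] = k
    v_buffer[loc] = v

    assert torch.equal(kv_data[loc, 0], k_buffer[loc])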