sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
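The new log_metrics flag lets internal probes opt out of request metrics. A minimal sketch, assuming a running server whose /generate route parses its JSON body into GenerateReqInput; the URL, port, and prompt are illustrative:

```python
# Hedged sketch: field names mirror GenerateReqInput; URL/port are assumptions.
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "ping",
        "sampling_params": {"max_new_tokens": 1},
        "log_metrics": False,  # keep liveness probes out of throughput metrics
    },
)
print(resp.status_code)
```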
@@ -66,6 +69,8 @@ class GenerateReqInput:
 
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
+    # Custom logit processor (serialized function)
+    custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     def normalize_batch_and_arguments(self):
         if (
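The processor travels as a serialized string. A hedged sketch of producing one, assuming the new sglang/srt/sampling/custom_logit_processor.py module (+38 lines in the file list) exposes a CustomLogitProcessor base class with a to_str() serializer; the subclass and token id below are illustrative:

```python
# Hedged sketch: CustomLogitProcessor and to_str() are assumptions based on
# the new sglang/srt/sampling/custom_logit_processor.py module.
import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BanTokenProcessor(CustomLogitProcessor):
    def __call__(self, logits: torch.Tensor, custom_param_list=None):
        logits[..., 7] = float("-inf")  # hypothetical banned token id
        return logits


# The resulting string is what GenerateReqInput.custom_logit_processor carries.
serialized = BanTokenProcessor().to_str()
```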
@@ -180,6 +185,13 @@ class GenerateReqInput:
         else:
             assert self.parallel_sample_num == 1
 
+        if self.custom_logit_processor is None:
+            self.custom_logit_processor = [None] * num
+        elif not isinstance(self.custom_logit_processor, list):
+            self.custom_logit_processor = [self.custom_logit_processor] * num
+        else:
+            assert self.parallel_sample_num == 1
+
     def regenerate_rid(self):
         self.rid = uuid.uuid4().hex
         return self.rid
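This applies the same batch-broadcast rule used for the other per-request fields: None expands to one None per request, a single serialized processor is replicated, and an explicit list is accepted only when parallel_sample_num == 1. Restated in isolation:

```python
# Pure-Python restatement of the broadcast rule in the hunk above.
def broadcast(value, num, parallel_sample_num=1):
    if value is None:
        return [None] * num
    if not isinstance(value, list):
        return [value] * num
    assert parallel_sample_num == 1  # explicit lists only without parallel sampling
    return value

assert broadcast(None, 3) == [None, None, None]
assert broadcast("proc", 2) == ["proc", "proc"]
```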
@@ -196,8 +208,14 @@ class GenerateReqInput:
             top_logprobs_num=self.top_logprobs_num[i],
             return_text_in_logprobs=self.return_text_in_logprobs,
             stream=self.stream,
+            log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            custom_logit_processor=(
+                self.custom_logit_processor[i]
+                if self.custom_logit_processor is not None
+                else None
+            ),
         )
 
 
@@ -230,6 +248,10 @@ class TokenizedGenerateReqInput:
     # Session info for continual prompting
     session_params: Optional[SessionParams] = None
 
+    # Custom logit processor (serialized function)
+    # TODO (hpguo): Add an example and update doc string here
+    custom_logit_processor: Optional[str] = None
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -243,6 +265,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
 
     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
@@ -340,7 +364,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -366,7 +389,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -491,6 +513,7 @@ class ProfileReq(Enum):
 @dataclass
 class ConfigureLoggingReq:
     log_requests: Optional[bool] = None
+    log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
     dump_requests_threshold: Optional[int] = None
 
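The companion script sglang/srt/managers/configure_logging.py (touched above) issues this message. Assuming the HTTP server exposes a matching /configure_logging route, with the route name and payload shape inferred from this dataclass and therefore assumptions, a runtime reconfiguration could look like:

```python
# Hedged sketch: route and field names are inferred from ConfigureLoggingReq.
import requests

requests.post(
    "http://localhost:30000/configure_logging",
    json={
        "log_requests": True,
        "log_requests_level": 1,  # new verbosity knob in this release
        "dump_requests_folder": "/tmp/sglang_dumps",
        "dump_requests_threshold": 1000,
    },
)
```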
sglang/srt/managers/schedule_batch.py
@@ -52,7 +52,6 @@ from sglang.srt.server_args import ServerArgs
 if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
 
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -65,9 +64,9 @@ global_server_args_dict = {
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -116,14 +115,18 @@ class FINISH_LENGTH(BaseFinishReason):
 
 
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type
 
     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }
 
 
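With the two new fields, an abort can carry an HTTP status and an error class back to the client. Illustrative usage with made-up values:

```python
# Illustrative values; FINISH_ABORT is the class shown in the hunk above.
from sglang.srt.managers.schedule_batch import FINISH_ABORT

abort = FINISH_ABORT("input too long", status_code=400, err_type="BadRequestError")
print(abort.to_json())
# {'type': 'abort', 'message': 'input too long',
#  'status_code': 400, 'err_type': 'BadRequestError'}
```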
@@ -148,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    im_start_id: Optional[torch.Tensor] = None
+    im_end_id: Optional[torch.Tensor] = None
+    slice_start_id: Optional[torch.Tensor] = None
+    slice_end_id: Optional[torch.Tensor] = None
+    tgt_sizes: Optional[list] = None
+
     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
@@ -167,6 +179,11 @@
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
+            "im_start_id",
+            "im_end_id",
+            "slice_start_id",
+            "slice_end_id",
+            "tgt_sizes",
         ]
         for arg in optional_args:
             if arg in obj:
@@ -215,6 +232,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        custom_logit_processor: Optional[str] = None,
         eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
@@ -226,14 +244,16 @@
             else origin_input_ids  # Before image padding
         )
         self.origin_input_ids = origin_input_ids
-        self.output_ids = []  # Each decode stage's output ids
-        self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+        # Each decode stage's output ids
+        self.output_ids = []
+        # fill_ids = origin_input_ids + output_ids. Updated if chunked.
         self.session_id = session_id
         self.input_embeds = input_embeds
 
         # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.custom_logit_processor = custom_logit_processor
 
         # Memory pool info
         self.req_pool_idx = None
@@ -265,6 +285,7 @@
         # Prefix info
         self.prefix_indices = []
         # Tokens to run prefill. input_tokens - shared_prefix_tokens.
+        # Updated if chunked.
         self.extend_input_len = 0
         self.last_node = None
 
@@ -280,11 +301,10 @@
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.normalized_prompt_logprob = None
-        self.input_token_logprobs_val = None
-        self.input_token_logprobs_idx = None
-        self.input_top_logprobs_val = None
-        self.input_top_logprobs_idx = None
+        self.input_token_logprobs_val: Optional[List[float]] = None
+        self.input_token_logprobs_idx: Optional[List[int]] = None
+        self.input_top_logprobs_val: Optional[List[float]] = None
+        self.input_top_logprobs_idx: Optional[List[int]] = None
 
         if return_logprob:
             self.output_token_logprobs_val = []
@@ -344,9 +364,6 @@
             max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
         max_prefix_len = max(max_prefix_len, 0)
@@ -578,6 +595,9 @@ class ScheduleBatch:
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[SpecInfo] = None
 
+    # Enable custom logit processor
+    enable_custom_logit_processor: bool = False
+
     @classmethod
     def init_new(
         cls,
@@ -588,6 +608,7 @@
         model_config: ModelConfig,
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
+        enable_custom_logit_processor: bool,
     ):
         return cls(
             reqs=reqs,
@@ -601,6 +622,7 @@
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
+            enable_custom_logit_processor=enable_custom_logit_processor,
         )
 
     def batch_size(self):
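The flag threaded through init_new above presumably originates from a server argument (server_args.py changes by +57/-14 in the file list). A hedged launch sketch; the keyword is inferred from the field name and is an assumption, as is the model path:

```python
# Hedged sketch: the kwarg below mirrors the ScheduleBatch field name and is
# assumed to exist on ServerArgs / Engine in this release.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    enable_custom_logit_processor=True,
)
```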
@@ -656,7 +678,7 @@
             or len(req.prefix_indices) >= im.num_image_tokens
         )
 
-        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -690,7 +712,7 @@
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
 
@@ -766,10 +788,10 @@
         self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
         self.input_embeds = (
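A plausible motivation for the int32 to int64 switches in this and the neighboring hunks: sequence lengths and pool indices feed sums, cumulative sums, and tensor indexing, where int32 silently wraps once totals cross 2**31 - 1. A self-contained demonstration of the wrap:

```python
import torch

# int32 cumulative sums wrap past 2**31 - 1; int64 stays correct.
lens = torch.full((2,), 2**30, dtype=torch.int32)
print(lens.cumsum(0))                  # second entry wraps to -2147483648
print(lens.to(torch.int64).cumsum(0))  # tensor([1073741824, 2147483648])
```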
@@ -1002,11 +1024,16 @@
     def prepare_for_idle(self):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
-        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
+        self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1067,7 +1094,7 @@
             self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
 
         self.reqs = [self.reqs[i] for i in keep_indices]
-        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+        new_indices = torch.tensor(keep_indices, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = self.req_pool_indices[new_indices]
@@ -1121,7 +1148,7 @@
             self.spec_info.merge_batch(other.spec_info)
 
     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
@@ -1136,7 +1163,6 @@
 
         global bid
         bid += 1
-
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
@@ -1180,6 +1206,7 @@
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
+            enable_custom_logit_processor=self.enable_custom_logit_processor,
         )
 
     def __str__(self):
sglang/srt/managers/schedule_policy.py
@@ -24,6 +24,7 @@ import torch
 
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
@@ -250,23 +251,24 @@ class PrefillAdder:
     def __init__(
         self,
         tree_cache: BasePrefixCache,
+        token_to_kv_pool: BaseTokenToKVPool,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
-        rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
+        self.token_to_kv_pool = token_to_kv_pool
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
-        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
-        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_token_offset = mixed_with_decode_tokens
+        self.cur_rem_token_offset = mixed_with_decode_tokens
 
         self.req_states = None
         self.can_run_list = []
@@ -275,8 +277,7 @@
         self.log_input_tokens = 0
 
         if running_batch is not None:
-            # Pre-remove the tokens which will be occupied by the running requests
-            self.rem_total_tokens -= sum(
+            self.rem_total_token_offset += sum(
                 [
                     min(
                         (r.sampling_params.max_new_tokens - len(r.output_ids)),
@@ -287,6 +288,22 @@
                 ]
             )
 
+    @property
+    def rem_total_tokens(self):
+        return (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+            - self.rem_total_token_offset
+        )
+
+    @property
+    def cur_rem_tokens(self):
+        return (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+            - self.cur_rem_token_offset
+        )
+
     def budget_state(self):
         if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
             return AddReqResult.NO_TOKEN
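The refactor replaces mutated remaining-token counters with offsets over live pool state: available_size() and evictable_size() are re-read on every property access, so the scheduler's budget can no longer drift from the actual KV cache. The pattern reduced to a toy:

```python
# Toy restatement of the offset pattern above; no sglang types required.
class Budget:
    def __init__(self, available):
        self._available = available  # stands in for pool free + evictable size
        self.offset = 0              # analogue of rem_total_token_offset

    @property
    def remaining(self):             # analogue of rem_total_tokens
        return self._available() - self.offset

state = {"free": 100}
budget = Budget(lambda: state["free"])
budget.offset += 30      # a scheduled prefill reserves tokens
state["free"] -= 10      # the pool shrinks independently
print(budget.remaining)  # 60: both effects are reflected on every read
```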
@@ -301,8 +318,8 @@
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
-        self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.cur_rem_tokens -= extend_input_len
+        self.rem_total_token_offset += extend_input_len + max_new_tokens
+        self.cur_rem_token_offset += extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -332,12 +349,10 @@
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
         try:
-            delta = self.tree_cache.inc_lock_ref(last_node)
-            self.rem_total_tokens += delta
+            self.tree_cache.inc_lock_ref(last_node)
             yield None
         finally:
-            delta = self.tree_cache.dec_lock_ref(last_node)
-            self.rem_total_tokens += delta
+            self.tree_cache.dec_lock_ref(last_node)
 
     def add_one_req_ignore_eos(self, req: Req):
         def add_req_state(r, insert_sort=False):
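Since the budget now derives from evictable_size() directly, _lock_node no longer folds the inc/dec deltas back into a counter and reduces to a plain pin/unpin guard. Its generic shape:

```python
# Generic pin/unpin guard mirroring the simplified _lock_node above.
from contextlib import contextmanager

@contextmanager
def lock_node(cache, node):
    cache.inc_lock_ref(node)  # pin: keep the node from being evicted
    try:
        yield
    finally:
        cache.dec_lock_ref(node)  # always unpin, even if scheduling raises
```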
@@ -433,7 +448,6 @@
                 or input_tokens <= self.rem_chunk_tokens
                 or (
                     req.return_logprob
-                    and req.normalized_prompt_logprob is None
                     and req.logprob_start_len != len(req.origin_input_ids) - 1
                 )
             ):