sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py

@@ -226,11 +226,11 @@ class GenerateReqInput:

         # Expand input based on type
         self._expand_inputs(num)
+        self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
-        self._normalize_rid(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)

@@ -530,6 +530,7 @@ class EmbeddingReqInput:
         if self.text is not None:
             if isinstance(self.text, list):
                 self.batch_size += len(self.text)
+                self.is_single = False
             else:
                 self.batch_size += 1

@@ -537,12 +538,10 @@ class EmbeddingReqInput:
         if self.input_ids is not None:
             if isinstance(self.input_ids[0], list):
                 self.batch_size += len(self.input_ids)
+                self.is_single = False
             else:
                 self.batch_size += 1

-        if self.batch_size > 1:
-            self.is_single = False
-
         # Fill in default arguments
         if self.is_single:
             if self.rid is None:
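
Note: moving `is_single = False` into the list branches is a behavior change for singleton lists, not just a cleanup: the old `batch_size > 1` check left a one-element list marked as single. A minimal standalone sketch of the new control flow (illustrative names, not the actual class):

    from typing import List, Union

    def classify(text: Union[str, List[str]]):
        # Mirrors the normalization above: any list input is now a batch.
        batch_size, is_single = 0, True
        if isinstance(text, list):
            batch_size += len(text)
            is_single = False
        else:
            batch_size += 1
        return batch_size, is_single

    assert classify("hi") == (1, True)
    assert classify(["hi"]) == (1, False)  # old logic kept is_single=True here
    assert classify(["a", "b"]) == (2, False)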
@@ -812,7 +811,9 @@ class GetWeightsByNameReqOutput:

 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None


 @dataclass
@@ -822,7 +823,9 @@ class ReleaseMemoryOccupationReqOutput:

 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None


 @dataclass
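
Note: the new `tags` field lets an RL trainer stage memory release and resume per region instead of all-or-nothing. A hedged usage sketch (assumes the `sglang.Engine` release/resume methods forward the new optional `tags` argument to these request structs; `trainer` is hypothetical):

    # Free only the model weights while the KV cache stays resident.
    engine.release_memory_occupation(tags=["weights"])
    trainer.train_step()  # use the freed GPU memory elsewhere
    # Re-allocate the weights region, then load fresh weights into it.
    engine.resume_memory_occupation(tags=["weights"])
    engine.update_weights_from_tensor(named_tensors)

Passing `tags=None` should keep the previous behavior of releasing everything.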
@@ -861,12 +864,6 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]


-@dataclass
-class V1RerankReqInput:
-    query: str
-    documents: List[str]
-
-
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
sglang/srt/managers/schedule_batch.py

@@ -38,7 +38,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -85,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
+    "enable_flashinfer_moe",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
     "deepep_config",
@@ -99,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
+    "weight_loader_disable_mmap",
 ]

 # Put some global args for easy access
@@ -436,7 +439,7 @@ class Req:
         self,
         rid: str,
         origin_input_text: str,
-        origin_input_ids: Tuple[int],
+        origin_input_ids: List[int],
         sampling_params: SamplingParams,
         return_logprob: bool = False,
         top_logprobs_num: int = 0,
@@ -467,7 +470,7 @@ class Req:
         # Each decode stage's output ids
         self.output_ids = []
         # fill_ids = origin_input_ids + output_ids. Updated if chunked.
-        self.fill_ids = None
+        self.fill_ids = []
         self.session_id = session_id
         self.input_embeds = input_embeds

@@ -519,13 +522,14 @@ class Req:

         # Prefix info
         # The indices to kv cache for the shared prefix.
-        self.prefix_indices = []
+        self.prefix_indices: torch.Tensor = []
         # Number of tokens to run prefill.
         self.extend_input_len = 0
         # The relative logprob_start_len in an extend batch
         self.extend_logprob_start_len = 0
-        self.last_node = None
-        self.last_node_global = None
+        self.last_node: Any = None
+        self.last_host_node: Any = None
+        self.host_hit_length = 0

         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -583,6 +587,7 @@ class Req:
             self.output_token_ids_logprobs_idx
         ) = None
         self.hidden_states: List[List[float]] = []
+        self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP

         # Embedding (return values)
         self.embedding = None
@@ -644,29 +649,17 @@ class Req:
     def init_next_round_input(
         self,
         tree_cache: Optional[BasePrefixCache] = None,
-        enable_hierarchical_cache=False,
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-            # tree cache is None if the prefix is not computed with tree cache.
-            if enable_hierarchical_cache:
-                self.prefix_indices, self.last_node, self.last_node_global = (
-                    tree_cache.match_prefix(
-                        key=self.adjust_max_prefix_ids(), include_evicted=True
-                    )
-                )
-            else:
-                self.prefix_indices, self.last_node = tree_cache.match_prefix(
-                    rid=self.rid, key=self.adjust_max_prefix_ids()
-                )
-        elif enable_hierarchical_cache:
-            # in case last_node is evicted during scheduling, we need to update the prefix_indices
-            while self.last_node.evicted:
-                self.prefix_indices = self.prefix_indices[
-                    : -len(self.last_node.host_value)
-                ]
-                self.last_node = self.last_node.parent
-
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=self.adjust_max_prefix_ids(),
+            )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
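
Note: `match_prefix` now always returns a 4-tuple, so callers no longer branch on `enable_hierarchical_cache`. The contract implied by the unpacking above (types are an assumption based on the surrounding annotations):

    # prefix_indices: torch.Tensor of KV indices already on the device
    # last_node: deepest matched node in the device radix tree
    # last_host_node: deepest matched node including host (CPU) entries
    # host_hit_length: extra prefix tokens resident only on the host
    prefix_indices, last_node, last_host_node, host_hit_length = (
        tree_cache.match_prefix(key=req.adjust_max_prefix_ids())
    )

When hierarchical caching is off, a cache can satisfy this by returning `last_host_node = last_node` and `host_hit_length = 0`.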
@@ -796,6 +789,7 @@ class Req:
         self.multimodal_inputs = None
         self.grammar = None
         self.origin_input_ids = [0]  # set it to one token to skip the long prefill
+        self.return_logprob = False
         self.finished_reason = FINISH_ABORT(
             error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
         )
@@ -820,7 +814,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs
@@ -862,6 +856,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    is_extend_in_batch: bool = False
     tbo_split_seq_index: Optional[int] = None
     global_forward_mode: Optional[ForwardMode] = None
@@ -908,12 +903,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Whether to return hidden states
     return_hidden_states: bool = False

+    # hicache pointer for synchronizing data loading from CPU to GPU
+    hicache_consumer_index: int = 0
+
     @classmethod
     def init_new(
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
@@ -1365,7 +1363,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             return len(self.reqs)
         # In the decoding phase, the length of a request's KV cache should be
         # the total length of the request minus 1
-        return sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+        return (
+            sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+            if self.enable_overlap
+            else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+        )

     def check_decode_mem(self, buf_multiplier=1):
         tokens_required = (
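
Note: a plausible reading of the overlap branch is that the overlap scheduler allocates the next decode token one step ahead, shifting the page-boundary test by one. A worked check with page_size = 4:

    page_size = 4
    for seqlen in (7, 8, 9):
        non_overlap = (seqlen - 1) % page_size == 0  # token seqlen-1 opens a page
        overlap = seqlen % page_size == 0            # pre-allocating token seqlen
        print(seqlen, non_overlap, overlap)
    # 7 False False
    # 8 False True   <- overlap mode reserves the new page one step earlier
    # 9 True  False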
@@ -1734,6 +1736,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
+            hicache_consumer_index=self.hicache_consumer_index,
             capture_hidden_mode=(
                 CaptureHiddenMode.FULL
                 if self.return_hidden_states
@@ -1760,11 +1763,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             decoding_reqs=self.decoding_reqs,
             spec_algorithm=self.spec_algorithm,
             enable_custom_logit_processor=self.enable_custom_logit_processor,
+            global_num_tokens=self.global_num_tokens,
+            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            is_extend_in_batch=self.is_extend_in_batch,
         )

     def __str__(self):
         return (
-            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
             f"#req={(len(self.reqs))})"
         )

@@ -1833,6 +1840,8 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
+    spec_num_draft_tokens: Optional[int] = None
+    hicache_consumer_index: int = 0

     # Overlap event
     launch_done: Optional[threading.Event] = None
sglang/srt/managers/schedule_policy.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,15 +20,17 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union

 import torch

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
@@ -51,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )


+IGNORE_EOS_RESERVE_TOKENS = 1
+
+
 class CacheAwarePolicy(Enum):
     """Scheduling policies that are aware of the tree cache."""

@@ -90,7 +97,7 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]) -> bool:
         if self.policy == CacheAgnosticPolicy.FCFS:
             # A shortcut for FCFS
-            return
+            return False

         policy = self._determine_active_policy(waiting_queue)

@@ -134,7 +141,7 @@
         """
         try:
             policy_enum = CacheAwarePolicy(policy)
-            if tree_cache.disable:
+            if getattr(tree_cache, "disable", True):
                 # If tree_cache is disabled, using CacheAgnosticPolicy policy
                 return CacheAgnosticPolicy.FCFS
             return policy_enum
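
Note: `getattr` with a default of `True` makes the policy resolution robust for cache objects that never define `disable` (treated as disabled, hence FCFS) instead of raising. A two-line illustration:

    class DummyCache:  # hypothetical cache without a `disable` attribute
        pass

    print(getattr(DummyCache(), "disable", True))  # True -> falls back to FCFS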
@@ -158,14 +165,9 @@
             prefix_ids = r.adjust_max_prefix_ids()

             # NOTE: the prefix_indices must always be aligned with last_node
-            if self.enable_hierarchical_cache:
-                r.prefix_indices, r.last_node, r.last_node_global = (
-                    self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
-                )
-            else:
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
-                )
+            r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
+                self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
+            )

             # NOTE(sang): This logic is for in-batch prefix caching;
             # If there are more than 1 request that have small matching prefix from
@@ -175,7 +177,7 @@
             # threshold means we cannot use in-batch prefix caching for short prefixes.
             # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
             if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                in_batch_matching_prefixes, _ = (
+                in_batch_matching_prefixes, _, _, _ = (
                     self.waiting_queue_radix_tree.match_prefix(
                         rid=r.rid, key=prefix_ids
                     )
@@ -268,14 +270,16 @@ class AddReqResult(Enum):
 class PrefillAdder:
     def __init__(
         self,
+        page_size: int,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
         mixed_with_decode_tokens: int = 0,
     ):
+        self.page_size = page_size
         self.tree_cache = tree_cache
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.running_batch = running_batch
@@ -292,6 +296,7 @@
         self.can_run_list = []
         self.new_chunked_req = None
         self.log_hit_tokens = 0
+        # TODO(lsyin): report the real input tokens excluding page alignment
         self.log_input_tokens = 0

         if running_batch is not None:
@@ -322,6 +327,9 @@
             - self.cur_rem_token_offset
         )

+    def ceil_paged_tokens(self, tokens: int) -> int:
+        return -(-tokens // self.page_size) * self.page_size
+
     def budget_state(self):
         if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
             return AddReqResult.NO_TOKEN
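
Note: `-(-tokens // page_size)` is the usual ceiling-division idiom for non-negative ints; multiplying back by `page_size` rounds a token count up to whole pages. A standalone check of the identity:

    import math

    def ceil_paged_tokens(tokens: int, page_size: int) -> int:
        # Same expression as the new method, with page_size as a parameter.
        return -(-tokens // page_size) * page_size

    assert ceil_paged_tokens(1, 4) == 4
    assert ceil_paged_tokens(4, 4) == 4
    assert ceil_paged_tokens(5, 4) == 8
    assert all(
        ceil_paged_tokens(t, 4) == math.ceil(t / 4) * 4 for t in range(256)
    )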
@@ -333,9 +341,12 @@

         return AddReqResult.CONTINUE

-    def _prefill_one_req(
+    def _update_prefill_budget(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
+        # TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
+        extend_input_len = self.ceil_paged_tokens(extend_input_len)
+
         self.rem_total_token_offset += extend_input_len + max_new_tokens
         self.cur_rem_token_offset += extend_input_len
         self.rem_input_tokens -= extend_input_len
@@ -350,7 +361,7 @@
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
-        self._prefill_one_req(
+        self._update_prefill_budget(
             0,
             req.extend_input_len,
             (
@@ -372,6 +383,12 @@
             self.tree_cache.dec_lock_ref(last_node)

     def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
+        # Early exit if no enough tokens for the input tokens
+        if self.ceil_paged_tokens(req.extend_input_len) > min(
+            self.cur_rem_tokens, self.rem_total_tokens
+        ):
+            return AddReqResult.NO_TOKEN
+
         def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -381,15 +398,17 @@
             )
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)

-            if tokens_left > 0:
-                if not insert_sort:
-                    self.req_states.append((tokens_left, tokens_occupied))
-                else:
-                    i = 0
-                    for i in range(len(self.req_states)):
-                        if tokens_left <= self.req_states[i][0]:
-                            break
-                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+            if tokens_left <= 0:
+                return
+
+            if not insert_sort:
+                self.req_states.append((tokens_left, tokens_occupied))
+            else:
+                i = 0
+                for i in range(len(self.req_states)):
+                    if tokens_left <= self.req_states[i][0]:
+                        break
+                self.req_states.insert(i, (tokens_left, tokens_occupied))

         if self.req_states is None:
             self.req_states = []
@@ -406,13 +425,11 @@
             cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
             tokens_freed = 0
             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
+                # tokens_left gives a reservative calculation as the last token is not stored
                 bs = len(self.req_states) - i
-                if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
+                # reserve tokens for corner cases
+                if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
                     return AddReqResult.NO_TOKEN
                 tokens_freed += tokens_occupied

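Note: the rewritten check charges each waiting request its own `tokens_left` as the worst case and keeps `IGNORE_EOS_RESERVE_TOKENS` per still-running request in reserve. A small standalone simulation of the accounting (made-up numbers; `req_states` sorted ascending by tokens_left, as the insert-sort above maintains):

    IGNORE_EOS_RESERVE_TOKENS = 1

    def fits(cur_rem_tokens, req_states):
        # req_states: list of (tokens_left, tokens_occupied) pairs
        tokens_freed = 0
        for i, (tokens_left, tokens_occupied) in enumerate(req_states):
            bs = len(req_states) - i  # requests still decoding at this point
            min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
            if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
                return False
            tokens_freed += tokens_occupied  # request i ends, its KV is freed
        return True

    print(fits(100, [(10, 30), (40, 50)]))  # True
    print(fits(50, [(30, 10), (30, 10)]))   # False: 50 - 30*2 <= 2
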
@@ -422,7 +439,7 @@
         ):
             # Non-chunked prefill
             self.can_run_list.append(req)
-            self._prefill_one_req(
+            self._update_prefill_budget(
                 0,
                 req.extend_input_len,
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
@@ -438,55 +455,52 @@
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
             self.new_chunked_req = req
-            self._prefill_one_req(0, trunc_len, 0)
+            self._update_prefill_budget(0, trunc_len, 0)

         return self.budget_state()

-    def add_one_req(
-        self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
-    ):
+    def add_one_req(self, req: Req, has_chunked_req: bool):
         if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
             return self.add_one_req_ignore_eos(req, has_chunked_req)

         total_tokens = req.extend_input_len + min(
             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
         )
-        input_tokens = (
-            -(-req.extend_input_len // self.tree_cache.page_size)
-            * self.tree_cache.page_size
-        )
+
+        # adjusting the input_tokens based on host_hit_length and page_size
+        real_input_tokens = req.extend_input_len - req.host_hit_length
+        real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
         prefix_len = len(req.prefix_indices)

         if total_tokens >= self.rem_total_tokens:
             return AddReqResult.NO_TOKEN

-        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+        if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
             return AddReqResult.OTHER

         with self._lock_node(req.last_node):
-            if total_tokens > self.rem_total_tokens:
+            # self.rem_total_tokens may decrease after the lock acquisition
+            if total_tokens >= self.rem_total_tokens:
                 return AddReqResult.NO_TOKEN

-            if (
-                enable_hierarchical_cache
-                and req.last_node_global is not None
-                and req.last_node_global.evicted
-            ):
-                req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
-                    req.last_node_global, req.prefix_indices
+            if req.host_hit_length > 0:
+                new_indices, req.last_node = self.tree_cache.init_load_back(
+                    req.last_host_node, req.host_hit_length
                 )
+                req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
                 req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
-                input_tokens = (
-                    -(-req.extend_input_len // self.tree_cache.page_size)
-                    * self.tree_cache.page_size
-                )
                 prefix_len = len(req.prefix_indices)

+            input_tokens = self.ceil_paged_tokens(req.extend_input_len)
+
+            if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
+                return AddReqResult.OTHER
+
             if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
                 # Non-chunked prefill
                 self.can_run_list.append(req)
                 self.tree_cache.inc_lock_ref(req.last_node)
-                self._prefill_one_req(
+                self._update_prefill_budget(
                     prefix_len,
                     input_tokens,
                     min(
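
Note: with `host_hit_length` explicit, the load-back path requests exactly the host-resident suffix and concatenates the returned device indices onto the prefix. A sketch of the index arithmetic only (tensor values are stand-ins; `init_load_back` itself is the real API from the diff):

    import torch

    prefix_indices = torch.arange(6)   # 6 prefix tokens already on the GPU
    host_hit_length = 4                # 4 more tokens resident only on host
    fill_ids = list(range(16))         # 16 tokens to run in this round

    # Stand-in for: new_indices, last_node = tree_cache.init_load_back(...)
    new_indices = torch.arange(6, 6 + host_hit_length)
    prefix_indices = torch.cat([prefix_indices, new_indices])

    extend_input_len = len(fill_ids) - len(prefix_indices)
    print(extend_input_len)  # 6 tokens still need a real prefill pass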
@@ -496,7 +510,7 @@
                 )
             else:
                 # Make sure at least one page is available
-                trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+                trunc_len = self.rem_chunk_tokens - self.page_size + 1
                 if trunc_len <= 0:
                     return AddReqResult.OTHER

@@ -507,6 +521,6 @@
                 self.can_run_list.append(req)
                 self.new_chunked_req = req
                 self.tree_cache.inc_lock_ref(req.last_node)
-                self._prefill_one_req(prefix_len, trunc_len, 0)
+                self._update_prefill_budget(prefix_len, trunc_len, 0)

         return self.budget_state()