sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -1058,12 +1058,7 @@ class TokenizerManager:
                     "lora_path",
                 ]
             )
-            out_skip_names = set(
-                [
-                    "text",
-                    "output_ids",
-                ]
-            )
+            out_skip_names = set(["text", "output_ids", "embedding"])
         elif self.log_requests_level == 1:
             max_length = 2048
         elif self.log_requests_level == 2:
@@ -1140,13 +1135,21 @@ class TokenizerManager:
             remain_num_req = len(self.rid_to_state)

             if self.health_check_failed:
-                # if health check failed, we should exit immediately
+                # if health check failed, exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                     remain_num_req,
                 )
                 break

+            elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
+                # if force shutdown flag set, exit immediately
+                logger.error(
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
             logger.info(
                 f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
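Review note: the new `SGL_FORCE_SHUTDOWN` escape hatch is read through sglang's existing `get_bool_env_var` helper, so it is enabled purely via the environment. A minimal standalone sketch of how the flag is consumed (the real check lives in the SIGTERM watchdog above):

```python
# Truthy strings such as "true" or "1" enable the flag.
import os

os.environ["SGL_FORCE_SHUTDOWN"] = "true"  # e.g. exported before launching the server

from sglang.srt.utils import get_bool_env_var

if get_bool_env_var("SGL_FORCE_SHUTDOWN"):
    print("SIGTERM will skip graceful request draining and exit immediately")
```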
@@ -1223,7 +1226,7 @@ class TokenizerManager:
                     state.last_output_offset = len(state.output_ids)
                 else:
                     state.output_ids.extend(recv_obj.output_ids[i])
-                    output_token_ids = state.output_ids
+                    output_token_ids = state.output_ids.copy()

                 out_dict = {
                     "output_ids": output_token_ids,
sglang/srt/managers/tp_worker.py

@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -57,7 +58,7 @@ class TpModelWorker:
         nccl_port: int,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size
@@ -147,6 +148,15 @@ class TpModelWorker:
         # A reference make this class has the same member as TpModelWorkerClient
         self.worker = self

+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -88,6 +88,15 @@ class TpModelWorkerClient:
         if self.device == "cpu":
             self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return self.worker.get_worker_info()

@@ -146,6 +155,8 @@ class TpModelWorkerClient:
         input_ids = model_worker_batch.input_ids
         resolve_future_token_ids(input_ids, self.future_token_ids_map)

+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
             self.worker.forward_batch_generation(
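Review note: both worker variants now expose the same two hooks, so the scheduler can plumb a batch's `hicache_consumer_index` to whichever class is active. The stubs below only mirror the call sequence; `StubLayerDoneCounter` is a hypothetical stand-in for the counter object owned by the cache controller, not sglang's real API:

```python
class StubLayerDoneCounter:
    """Hypothetical stand-in for the cache controller's layer-done counter."""

    def __init__(self):
        self.consumer_index = -1

    def set_consumer(self, consumer_index):
        # pin this worker to the producer slot loading the batch's KV layers
        self.consumer_index = consumer_index


class StubWorker:
    """Mirrors the two hooks added to TpModelWorker / TpModelWorkerClient."""

    def __init__(self):
        self.hicache_layer_transfer_counter = None

    def register_hicache_layer_transfer_counter(self, counter):
        self.hicache_layer_transfer_counter = counter

    def set_hicache_consumer(self, consumer_index):
        if self.hicache_layer_transfer_counter is not None:
            self.hicache_layer_transfer_counter.set_consumer(consumer_index)


worker = StubWorker()
worker.register_hicache_layer_transfer_counter(StubLayerDoneCounter())  # once, at init
worker.set_hicache_consumer(0)  # per batch: model_worker_batch.hicache_consumer_index
```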
sglang/srt/mem_cache/{paged_allocator.py → allocator.py}

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,132 @@ limitations under the License.
 Page-aligned memory pool.
 """

+import abc
+from typing import TYPE_CHECKING
+
 import torch
 import triton
 import triton.language as tl

-from sglang.srt.mem_cache.memory_pool import KVCache
 from sglang.srt.utils import get_bool_env_var, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+class BaseTokenToKVPoolAllocator(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        self._kvcache = kvcache
+
+        self.free_pages = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def debug_print(self) -> str:
+        return ""
+
+    def available_size(self):
+        return len(self.free_pages) * self.page_size
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
+    def backup_state(self):
+        return self.free_pages
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.cat(self.free_group))
+
+    def get_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def alloc_extend(self, *args, **kwargs):
+        raise NotImplementedError("alloc_extend is only for paged allocator")
+
+    def alloc_decode(self, *args, **kwargs):
+        raise NotImplementedError("alloc_decode is only for paged allocator")
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def alloc(self, need_size: int):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(self, free_index: torch.Tensor):
+        raise NotImplementedError()
+
+
+class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+    """An allocator managing the indices to kv cache data."""
+
+    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
+        super().__init__(size, 1, dtype, device, kvcache)
+        self.clear()
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_pages = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def available_size(self):
+        # To avoid minor "len(free_pages) * 1" overhead
+        return len(self.free_pages)
+
+    def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            return None
+
+        select_index = self.free_pages[:need_size]
+        self.free_pages = self.free_pages[need_size:]
+        return select_index
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            self.free_pages = torch.cat((self.free_pages, free_index))
+        else:
+            self.free_group.append(free_index)
+
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+

 @triton.jit
 def alloc_extend_kernel(
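Review note: a quick usage sketch of the token-granularity allocator defined above. Passing `kvcache=None` is safe only because the sketch never touches the CPU-copy paths; real callers pass a `KVCache`:

```python
import torch

from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

alloc = TokenToKVPoolAllocator(size=8, dtype=torch.float16, device="cpu", kvcache=None)
idx = alloc.alloc(3)           # tensor([1, 2, 3]); slot 0 stays reserved for padding
print(alloc.available_size())  # 5

# Grouped frees are buffered and released as a single torch.cat at group end.
alloc.free_group_begin()
alloc.free(idx[:1])
alloc.free(idx[1:])
alloc.free_group_end()
print(alloc.available_size())  # 8
```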
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
     tl.store(out_indices + pid, page * page_size)


-class PagedTokenToKVPoolAllocator:
+class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """
     An allocator managing the indices to kv cache data.

@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
         device: str,
         kvcache: KVCache,
     ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = page_size
+        super().__init__(size, page_size, dtype, device, kvcache)
         self.num_pages = size // page_size
-
-        self.free_pages = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-
-        self._kvcache = kvcache
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
-
-    def available_size(self):
-        return len(self.free_pages) * self.page_size
-
-    def get_kvcache(self):
-        return self._kvcache
+        self.clear()

     def alloc(self, need_size: int):
         # page-aligned allocation, returning contiguous indices of pages
@@ -298,21 +404,6 @@
         if self.debug_mode:
             assert len(torch.unique(self.free_pages)) == len(self.free_pages)

-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_pages
-
-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
-
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
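Review note: `backup_state()`/`restore_state()` are now shared through the base class rather than duplicated per allocator. One plausible use is a try-then-roll-back pattern around tentative allocations; a sketch (again with `kvcache=None` since only index bookkeeping is exercised, and `draft_rejected` is a stand-in condition):

```python
import torch

from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

allocator = TokenToKVPoolAllocator(size=16, dtype=torch.float16, device="cpu", kvcache=None)

snapshot = allocator.backup_state()    # just the free-index tensor
indices = allocator.alloc(4)           # tentative allocation
draft_rejected = True                  # stand-in for a real acceptance check
if indices is None or draft_rejected:
    allocator.restore_state(snapshot)  # roll the allocator back wholesale
print(allocator.available_size())      # 16 again
```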
sglang/srt/mem_cache/base_prefix_cache.py

@@ -1,5 +1,31 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+
+import torch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+else:
+    Req = Any  # Placeholder for Req type when not type checking
+
+
+class MatchResult(NamedTuple):
+    """Result of a prefix match operation.
+
+    Attributes:
+        device_indices  : Indices of the KV cache on the device matched by common prefix.
+        last_device_node: The last TreeNode on the device that was matched.
+        last_host_node  : The last TreeNode on the host that was matched.
+                          Note that if HiCache is not enabled,
+                          this **must** be the same as `last_device_node`.
+        host_hit_length : Length of the KV cache hit on the host, if applicable.
+                          0 if HiCache is not enabled.
+    """
+
+    device_indices: torch.Tensor
+    last_device_node: Any
+    last_host_node: Any
+    host_hit_length: int = 0


 class BasePrefixCache(ABC):
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         pass

     @abstractmethod
-    def insert(self, **kwargs):
+    def cache_finished_req(self, req: Req, **kwargs):
         pass

     @abstractmethod
-    def cache_finished_req(self, **kwargs):
-        pass
-
-    @abstractmethod
-    def cache_unfinished_req(self, **kwargs):
+    def cache_unfinished_req(self, req: Req, **kwargs):
         pass

     @abstractmethod
@@ -49,5 +71,27 @@
     def pretty_print(self):
         raise NotImplementedError()

+    def init_load_back(
+        self,
+        last_host_node: Any,
+        host_hit_length: int,
+    ) -> Tuple[torch.Tensor, Any]:
+        """
+        Preparing KV cache loading from host to device.
+        """
+        raise NotImplementedError()
+
+    def ready_to_load_host_cache(self) -> Any:
+        """
+        Notify the cache controller to start the KV cache loading
+        """
+        raise NotImplementedError()
+
+    def check_hicache_events(self) -> Any:
+        """
+        Check HiCache related activities to update radix tree and synchronize across TP workers if needed
+        """
+        raise NotImplementedError()
+
     def take_events(self):
         return []
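Review note: since `MatchResult` is a `NamedTuple`, callers get named fields (and tuple unpacking where convenient), with `host_hit_length` defaulting to 0 for caches without HiCache. A small sketch of constructing and consuming one:

```python
import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult

res = MatchResult(
    device_indices=torch.arange(1, 4, dtype=torch.int64),
    last_device_node=None,  # a real TreeNode in actual use
    last_host_node=None,    # must equal last_device_node when HiCache is off
)
prefix_len = len(res.device_indices)  # 3 tokens already cached on device
if res.host_hit_length > 0:
    pass  # only taken when HiCache matched additional tokens on the host
print(prefix_len, res.host_hit_length)  # 3 0
```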
sglang/srt/mem_cache/chunk_cache.py

@@ -2,40 +2,38 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Any, Callable, List, Tuple
+from typing import TYPE_CHECKING, Any

 import torch

-from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req


-class ChunkCacheEntry:
-    def __init__(self, rid: str, value: torch.Tensor):
-        self.rid = rid
-        self.value = value
-
-
 class ChunkCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
-        self.disable = True

     def reset(self):
         pass

-    def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
-        return [], None
+    def match_prefix(self, **unused_kwargs) -> MatchResult:
+        return MatchResult(
+            device_indices=torch.empty((0,), dtype=torch.int64),
+            last_device_node=None,
+            last_host_node=None,
+        )

     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
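Review note: with `ChunkCache` returning a `MatchResult` as well (instead of the old `([], None)` pair), lookups no longer need to branch on the cache type. A sketch of the now-uniform call shape, where `cache` may be any `BasePrefixCache` subclass:

```python
from typing import List

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult


def lookup_prefix(cache: BasePrefixCache, token_ids: List[int]) -> MatchResult:
    # ChunkCache ignores the key and reports an empty match; RadixCache and
    # HiRadixCache return real device/host hits. Callers read the same fields.
    return cache.match_prefix(key=token_ids)
```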
@@ -54,9 +52,6 @@ class ChunkCache(BasePrefixCache):
         # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
         req.prefix_indices = kv_indices

-    def insert(self):
-        raise NotImplementedError()
-
     def evict(self, num_tokens: int):
         pass

sglang/srt/mem_cache/hiradix_cache.py

@@ -7,11 +7,12 @@ from typing import List, Optional
 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool_host import (
     MHATokenToKVPoolHost,
@@ -27,7 +28,7 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
@@ -283,39 +284,44 @@ class HiRadixCache(RadixCache):
     def init_load_back(
         self,
         last_node: TreeNode,
-        prefix_indices: torch.Tensor,
+        host_hit_length: int,
         mem_quota: Optional[int] = None,
     ):
-        assert (
-            len(prefix_indices) == 0 or prefix_indices.is_cuda
-        ), "indices of device kV caches should be on GPU"
+        _ = host_hit_length  # unused, but kept for compatibility
         if last_node.evicted:
             loading_values = self.load_back(last_node, mem_quota)
             if loading_values is not None:
-                prefix_indices = (
-                    loading_values
-                    if len(prefix_indices) == 0
-                    else torch.cat([prefix_indices, loading_values])
-                )
                 logger.debug(
                     f"loading back {len(loading_values)} tokens for node {last_node.id}"
                 )
+                return loading_values, last_node

             while last_node.evicted:
                 last_node = last_node.parent

-        return last_node, prefix_indices
+        return (
+            torch.empty((0,), dtype=torch.int64, device=self.device),
+            last_node,
+        )

-    def ready_to_load_cache(self):
+    def ready_to_load_host_cache(self):
+        producer_index = self.cache_controller.layer_done_counter.next_producer()
         self.load_cache_event.set()
+        return producer_index

-    def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
+    def check_hicache_events(self):
+        self.writing_check()
+        self.loading_check()
+
+    def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
         if self.disable or len(key) == 0:
-            if include_evicted:
-                return empty_value, self.root_node, self.root_node
-            else:
-                return empty_value, self.root_node
+            return MatchResult(
+                device_indices=empty_value,
+                last_device_node=self.root_node,
+                last_host_node=self.root_node,
+                host_hit_length=0,
+            )

         if self.page_size != 1:
             page_aligned_len = len(key) // self.page_size * self.page_size
@@ -327,14 +333,18 @@
         else:
             value = empty_value

-        last_node_global = last_node
+        host_hit_length = 0
+        last_host_node = last_node
         while last_node.evicted:
+            host_hit_length += len(last_node.host_value)
             last_node = last_node.parent

-        if include_evicted:
-            return value, last_node, last_node_global
-        else:
-            return value, last_node
+        return MatchResult(
+            device_indices=value,
+            last_device_node=last_node,
+            last_host_node=last_host_node,
+            host_hit_length=host_hit_length,
+        )

     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.monotonic()
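Review note: the `host_hit_length` bookkeeping above can be illustrated with stub nodes (not sglang's real `TreeNode`): walking up from the deepest matched node, each evicted node contributes the length of its host-resident value, and the walk stops at the last node still on device:

```python
class Node:
    """Stub node for illustration only."""

    def __init__(self, parent, evicted, host_value):
        self.parent, self.evicted, self.host_value = parent, evicted, host_value


root = Node(None, False, [])
a = Node(root, False, [0] * 4)  # still on device
b = Node(a, True, [0] * 8)      # evicted to host
c = Node(b, True, [0] * 2)      # evicted to host

last_node, host_hit_length = c, 0
while last_node.evicted:
    host_hit_length += len(last_node.host_value)
    last_node = last_node.parent

print(host_hit_length)  # 10 tokens loadable back from host
print(last_node is a)   # True -> becomes last_device_node
```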
@@ -372,6 +382,7 @@
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.loading = child.loading
+        new_node.hit_count = child.hit_count

         # split value and host value if exists
         if child.evicted: