sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py (new file)
@@ -0,0 +1,380 @@
+ import abc
+ import logging
+ import threading
+ from enum import IntEnum
+ from functools import wraps
+
+ import psutil
+ import torch
+
+ from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
+ from sglang.srt.utils import debug_timing
+
+ logger = logging.getLogger(__name__)
+
+
+ class MemoryStateInt(IntEnum):
+     IDLE = 0
+     RESERVED = 1
+     PROTECTED = 2
+     SYNCED = 3
+     BACKUP = 4
+
+
+ def synchronized(debug_only=False):
+     def _decorator(func):
+         @wraps(func)
+         def wrapper(self, *args, **kwargs):
+             if (not debug_only) or self.debug:
+                 return func(self, *args, **kwargs)
+                 with self.lock:
+                     return func(self, *args, **kwargs)
+             else:
+                 return True
+
+         return wrapper
+
+     return _decorator
+
+
+ class HostKVCache(abc.ABC):
+
+     def __init__(
+         self,
+         device_pool: KVCache,
+         host_to_device_ratio: float,
+         host_size: int,
+         pin_memory: bool,
+         device: str,
+         page_size: int,
+     ):
+         self.device_pool = device_pool
+         self.dtype = device_pool.store_dtype
+         self.pin_memory = pin_memory
+         self.device = device
+         self.page_size = page_size
+         self.size_per_token = self.get_size_per_token()
+         if host_size > 0:
+             self.size = int(host_size * 1e9 // self.size_per_token)
+         else:
+             self.size = int(device_pool.size * host_to_device_ratio)
+         # Align the host memory pool size to the page size
+         self.size = self.size - (self.size % self.page_size)
+         self.start_layer = device_pool.start_layer
+         self.end_layer = device_pool.end_layer
+
+         assert (
+             self.size > device_pool.size
+         ), "The host memory should be larger than the device memory with the current protocol"
+
+         # Verify there is enough available host memory.
+         host_mem = psutil.virtual_memory()
+         requested_bytes = self.size * self.size_per_token
+         # preserve at least 10GB for other usage
+         ten_gb = 10 * (1024**3)
+         if requested_bytes > host_mem.available - ten_gb:
+             raise ValueError(
+                 f"Not enough host memory available. Requesting "
+                 f"{requested_bytes / 1e9:.2f} GB but only have "
+                 f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
+                 f"size of the hierarchical cache."
+             )
+         else:
+             logger.info(
+                 f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
+             )
+
+         self.kv_buffer = self.init_kv_buffer()
+
+         # A lock for synchronized operations on memory allocation and state transitions.
+         self.lock = threading.RLock()
+         self.debug = logger.isEnabledFor(logging.DEBUG)
+         self.clear()
+
+     @abc.abstractmethod
+     def get_size_per_token(self):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def init_kv_buffer(self):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def transfer(self, indices, flat_data):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def get_flat_data(self, indices):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def get_flat_data_by_layer(self, indices, layer_id):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def assign_flat_data(self, indices, flat_data):
+         raise NotImplementedError()
+
+     @synchronized()
+     def clear(self):
+         # Initialize memory states and tracking structures.
+         self.mem_state = torch.zeros(
+             (self.size,), dtype=torch.uint8, device=self.device
+         )
+         self.free_slots = torch.arange(self.size, dtype=torch.int64)
+
+     def available_size(self):
+         return len(self.free_slots)
+
+     @synchronized()
+     def alloc(self, need_size: int) -> torch.Tensor:
+         if need_size > self.available_size():
+             return None
+
+         select_index = self.free_slots[:need_size]
+         self.free_slots = self.free_slots[need_size:]
+
+         if self.debug:
+             self.mem_state[select_index] = MemoryStateInt.RESERVED
+
+         return select_index
+
+     @synchronized()
+     def free(self, indices: torch.Tensor) -> int:
+         self.free_slots = torch.cat([self.free_slots, indices])
+         if self.debug:
+             self.mem_state[indices] = MemoryStateInt.IDLE
+         return len(indices)
+
+     @synchronized(debug_only=True)
+     def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+         assert len(indices) > 0, "The indices should not be empty"
+         states = self.mem_state[indices]
+         assert (
+             states == states[0]
+         ).all(), "The memory slots should have the same state {}".format(states)
+         return MemoryStateInt(states[0].item())
+
+     @synchronized(debug_only=True)
+     def is_reserved(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.RESERVED
+
+     @synchronized(debug_only=True)
+     def is_protected(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.PROTECTED
+
+     @synchronized(debug_only=True)
+     def is_synced(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.SYNCED
+
+     @synchronized(debug_only=True)
+     def is_backup(self, indices: torch.Tensor) -> bool:
+         return self.get_state(indices) == MemoryStateInt.BACKUP
+
+     @synchronized(debug_only=True)
+     def update_backup(self, indices: torch.Tensor):
+         if not self.is_synced(indices):
+             raise ValueError(
+                 f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                 f"Current state: {self.get_state(indices)}"
+             )
+         self.mem_state[indices] = MemoryStateInt.BACKUP
+
+     @synchronized(debug_only=True)
+     def update_synced(self, indices: torch.Tensor):
+         self.mem_state[indices] = MemoryStateInt.SYNCED
+
+     @synchronized(debug_only=True)
+     def protect_write(self, indices: torch.Tensor):
+         if not self.is_reserved(indices):
+             raise ValueError(
+                 f"The host memory slots should be RESERVED before write operations. "
+                 f"Current state: {self.get_state(indices)}"
+             )
+         self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+     @synchronized(debug_only=True)
+     def protect_load(self, indices: torch.Tensor):
+         if not self.is_backup(indices):
+             raise ValueError(
+                 f"The host memory slots should be in BACKUP state before load operations. "
+                 f"Current state: {self.get_state(indices)}"
+             )
+         self.mem_state[indices] = MemoryStateInt.PROTECTED
+
+     @synchronized(debug_only=True)
+     def complete_io(self, indices: torch.Tensor):
+         if not self.is_protected(indices):
+             raise ValueError(
+                 f"The host memory slots should be PROTECTED during I/O operations. "
+                 f"Current state: {self.get_state(indices)}"
+             )
+         self.mem_state[indices] = MemoryStateInt.SYNCED
+
+
+ class MHATokenToKVPoolHost(HostKVCache):
+     device_pool: MHATokenToKVPool
+
+     def __init__(
+         self,
+         device_pool: MHATokenToKVPool,
+         host_to_device_ratio: float,
+         host_size: int,
+         page_size: int,
+         pin_memory: bool = True,
+         device: str = "cpu",
+     ):
+         super().__init__(
+             device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+         )
+
+     def get_size_per_token(self):
+         self.head_num = self.device_pool.head_num
+         self.head_dim = self.device_pool.head_dim
+         self.layer_num = self.device_pool.layer_num
+
+         return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+     def init_kv_buffer(self):
+         return torch.empty(
+             (2, self.layer_num, self.size, self.head_num, self.head_dim),
+             dtype=self.dtype,
+             device=self.device,
+             pin_memory=self.pin_memory,
+         )
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # backup prepared data from device to host
+         self.kv_buffer[:, :, indices] = flat_data.to(
+             device=self.device, non_blocking=False
+         )
+
+     def get_flat_data(self, indices):
+         return self.kv_buffer[:, :, indices]
+
+     def get_flat_data_by_layer(self, indices, layer_id):
+         return self.kv_buffer[:, layer_id - self.start_layer, indices]
+
+     def assign_flat_data(self, indices, flat_data):
+         self.kv_buffer[:, :, indices] = flat_data
+
+     def write_page_all_layers(self, host_indices, device_indices, device_pool):
+         device_indices_cpu = device_indices[:: self.page_size].cpu()
+         for i in range(len(device_indices_cpu)):
+             h_index = host_indices[i * self.page_size]
+             d_index = device_indices_cpu[i]
+             for j in range(self.layer_num):
+                 self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                     device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                     non_blocking=True,
+                 )
+                 self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                     device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                     non_blocking=True,
+                 )
+
+     def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+         device_indices_cpu = device_indices[:: self.page_size].cpu()
+         for i in range(len(device_indices_cpu)):
+             h_index = host_indices[i * self.page_size]
+             d_index = device_indices_cpu[i]
+             device_pool.k_buffer[layer_id - self.start_layer][
+                 d_index : d_index + self.page_size
+             ].copy_(
+                 self.kv_buffer[
+                     0, layer_id - self.start_layer, h_index : h_index + self.page_size
+                 ],
+                 non_blocking=True,
+             )
+             device_pool.v_buffer[layer_id - self.start_layer][
+                 d_index : d_index + self.page_size
+             ].copy_(
+                 self.kv_buffer[
+                     1, layer_id - self.start_layer, h_index : h_index + self.page_size
+                 ],
+                 non_blocking=True,
+             )
+
+
+ class MLATokenToKVPoolHost(HostKVCache):
+     device_pool: MLATokenToKVPool
+
+     def __init__(
+         self,
+         device_pool: MLATokenToKVPool,
+         host_to_device_ratio: float,
+         host_size: int,
+         page_size: int,
+         pin_memory: bool = True,
+         device: str = "cpu",
+     ):
+         super().__init__(
+             device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
+         )
+
+     def get_size_per_token(self):
+         self.kv_lora_rank = self.device_pool.kv_lora_rank
+         self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+         self.layer_num = self.device_pool.layer_num
+
+         return (
+             (self.kv_lora_rank + self.qk_rope_head_dim)
+             * 1
+             * self.dtype.itemsize
+             * self.layer_num
+         )
+
+     def init_kv_buffer(self):
+         return torch.empty(
+             (
+                 self.layer_num,
+                 self.size,
+                 1,
+                 self.kv_lora_rank + self.qk_rope_head_dim,
+             ),
+             dtype=self.dtype,
+             device=self.device,
+             pin_memory=self.pin_memory,
+         )
+
+     @debug_timing
+     def transfer(self, indices, flat_data):
+         # backup prepared data from device to host
+         self.kv_buffer[:, indices] = flat_data.to(
+             device=self.device, non_blocking=False
+         )
+
+     def get_flat_data(self, indices):
+         return self.kv_buffer[:, indices]
+
+     def get_flat_data_by_layer(self, indices, layer_id):
+         return self.kv_buffer[layer_id - self.start_layer, indices]
+
+     def assign_flat_data(self, indices, flat_data):
+         self.kv_buffer[:, indices] = flat_data
+
+     def write_page_all_layers(self, host_indices, device_indices, device_pool):
+         device_indices_cpu = device_indices[:: self.page_size].cpu()
+         for i in range(len(device_indices_cpu)):
+             h_index = host_indices[i * self.page_size]
+             d_index = device_indices_cpu[i]
+             for j in range(self.layer_num):
+                 self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                     device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                     non_blocking=True,
+                 )
+
+     def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+         device_indices_cpu = device_indices[:: self.page_size].cpu()
+         for i in range(len(device_indices_cpu)):
+             h_index = host_indices[i * self.page_size]
+             d_index = device_indices_cpu[i]
+             device_pool.kv_buffer[layer_id - self.start_layer][
+                 d_index : d_index + self.page_size
+             ].copy_(
+                 self.kv_buffer[
+                     layer_id - self.start_layer, h_index : h_index + self.page_size
+                 ],
+                 non_blocking=True,
+             )
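
The methods above define a five-state lifecycle for every host slot: alloc() takes slots from IDLE to RESERVED, protect_write() to PROTECTED for the device-to-host copy, complete_io() to SYNCED, update_backup() to BACKUP once the host copy may outlive the device copy, and protect_load() back to PROTECTED for a reload. Two details matter when reading this: the state guards only run when DEBUG logging is enabled for the module (self.debug is sampled once in __init__), and as the hunk reads, the wrapper in synchronized() returns before its `with self.lock:` block, so the lock is never actually taken. The sketch below walks one batch of slots through that lifecycle; it assumes sglang 0.4.8 on a CPU-only machine, and the stub device pool is hypothetical, carrying only the attributes HostKVCache reads.

    # Hedged sketch: stub_pool is NOT a real sglang pool; it only mimics the
    # attributes HostKVCache reads. The real device pools live in memory_pool.py.
    import logging
    import types

    import torch

    # Enable DEBUG before construction; self.debug is sampled in __init__ and
    # the state-machine guards are skipped without it.
    logging.getLogger("sglang.srt.mem_cache.memory_pool_host").setLevel(logging.DEBUG)

    from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost

    stub_pool = types.SimpleNamespace(
        store_dtype=torch.float16,
        size=1024,      # device slots
        start_layer=0,
        end_layer=1,
        head_num=2,
        head_dim=64,
        layer_num=2,
    )

    # size_per_token = head_dim * head_num * layer_num * itemsize * 2 (K and V)
    #                = 64 * 2 * 2 * 2 * 2 = 1024 bytes per token.
    # host_size=0 selects the ratio path: 1024 * 2.0 = 2048 host slots,
    # then rounded down to a multiple of page_size (already aligned here).
    host = MHATokenToKVPoolHost(
        stub_pool,
        host_to_device_ratio=2.0,
        host_size=0,
        page_size=32,
        pin_memory=False,  # pinned host memory needs CUDA; keep the sketch CPU-only
    )

    idx = host.alloc(64)     # IDLE -> RESERVED
    host.protect_write(idx)  # RESERVED -> PROTECTED, guards the device->host copy
    host.complete_io(idx)    # PROTECTED -> SYNCED
    host.update_backup(idx)  # SYNCED -> BACKUP, host copy is now authoritative
    host.protect_load(idx)   # BACKUP -> PROTECTED, guards a host->device reload
    host.complete_io(idx)    # PROTECTED -> SYNCED
    host.free(idx)           # slots return to the free list (IDLE)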
sglang/srt/mem_cache/radix_cache.py
@@ -23,7 +23,7 @@ import heapq
  import time
  from collections import defaultdict
  from functools import partial
- from typing import TYPE_CHECKING, List, Optional, Tuple
+ from typing import TYPE_CHECKING, List, Optional
 
  import torch
 
@@ -31,11 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
      AllBlocksCleared,
      BlockRemoved,
      BlockStored,
-     KVCacheEvent,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 
  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import Req
@@ -47,9 +46,9 @@ class TreeNode:
 
      def __init__(self, id: Optional[int] = None):
          self.children = defaultdict(TreeNode)
-         self.parent = None
-         self.key = None
-         self.value = None
+         self.parent: TreeNode = None
+         self.key: List[int] = None
+         self.value: Optional[torch.Tensor] = None
          self.lock_ref = 0
          self.last_access_time = time.monotonic()
 
@@ -57,7 +56,7 @@ class TreeNode:
          # indicating the node is loading KV cache from host
          self.loading = False
          # store the host indices of KV cache
-         self.host_value = None
+         self.host_value: Optional[torch.Tensor] = None
 
          self.id = TreeNode.counter if id is None else id
          TreeNode.counter += 1
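
The new annotations make the tree's shape explicit: key is the list of token IDs an edge covers, value holds the device KV indices for exactly those tokens, and host_value mirrors them on the host once the hierarchical cache backs the edge up. A small sketch of those invariants, assuming sglang 0.4.8 is installed; keying children by the first token is purely illustrative (the real child key depends on the configured page size):

    import torch

    from sglang.srt.mem_cache.radix_cache import TreeNode

    root = TreeNode()
    node = TreeNode()
    node.parent = root
    node.key = [101, 102, 103]            # token IDs covered by this edge
    node.value = torch.tensor([7, 8, 9])  # one device KV index per token
    root.children[101] = node             # illustrative child key

    # Invariant: one KV slot per cached token on the edge.
    assert len(node.key) == node.value.numel()
    # host_value stays None until the edge is backed up to host memory.
    assert node.host_value is None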
@@ -99,7 +98,7 @@ class RadixCache(BasePrefixCache):
      def __init__(
          self,
          req_to_token_pool: ReqToTokenPool,
-         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
          page_size: int,
          disable: bool = False,
          enable_kv_cache_events: bool = False,
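
This type-hint change tracks a module move visible in the file list ({paged_allocator.py → allocator.py}): the KV allocators left memory_pool.py, and the hierarchy now roots at BaseTokenToKVPoolAllocator. Code importing the old name has to migrate roughly as follows (the 0.4.7 line is the removed import from the hunk above):

    # sglang 0.4.7 (no longer works in 0.4.8):
    #   from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
    # sglang 0.4.8:
    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
    from sglang.srt.mem_cache.memory_pool import ReqToTokenPool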
@@ -135,7 +134,7 @@ class RadixCache(BasePrefixCache):
          self.protected_size_ = 0
          self._record_all_cleared_event()
 
-     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+     def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
          """Find the matching prefix from the radix tree.
          Args:
              key: A list of token IDs to find a matching prefix.
@@ -147,13 +146,14 @@
              than the last node's value.
          """
          if self.disable or len(key) == 0:
-             return (
-                 torch.empty(
+             return MatchResult(
+                 device_indices=torch.empty(
                      (0,),
                      dtype=torch.int64,
                      device=self.device,
                  ),
-                 self.root_node,
+                 last_device_node=self.root_node,
+                 last_host_node=self.root_node,
              )
 
          if self.page_size != 1:
@@ -165,7 +165,11 @@
              value = torch.cat(value)
          else:
              value = torch.empty((0,), dtype=torch.int64, device=self.device)
-         return value, last_node
+         return MatchResult(
+             device_indices=value,
+             last_device_node=last_node,
+             last_host_node=last_node,
+         )
 
      def insert(self, key: List, value=None):
          if self.disable:
@@ -235,7 +239,7 @@
          )
 
          # The prefix indices could be updated, reuse it
-         new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
+         new_indices, new_last_node, _, _ = self.match_prefix(page_aligned_token_ids)
          self.req_to_token_pool.write(
              (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
              new_indices[len(req.prefix_indices) :],
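
match_prefix() now returns a MatchResult (imported from base_prefix_cache in the hunk near the top of this file) instead of a bare (indices, node) tuple, so RadixCache and the hierarchical cache can share one signature. The keyword construction above fixes the first three fields, and the four-way unpacking in the hunk just shown implies a fourth; a sketch of the shape, where the fourth field's name (host_hit_length) is an assumption:

    from typing import Any, NamedTuple

    import torch

    class MatchResult(NamedTuple):
        device_indices: torch.Tensor  # KV indices of the matched prefix on device
        last_device_node: Any         # deepest node whose KV is resident on device
        last_host_node: Any           # deepest node counting host-offloaded KV too
        host_hit_length: int = 0      # extra prefix only on host (name assumed)

    res = MatchResult(
        device_indices=torch.empty((0,), dtype=torch.int64),
        last_device_node=None,
        last_host_node=None,
    )

    # Callers can use attribute access ...
    prefix_len = len(res.device_indices)
    # ... or positional unpacking, as the hunk above does:
    device_indices, last_device_node, _, _ = res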
@@ -461,23 +465,47 @@
          return ret_list
 
      def _record_store_event(self, node: TreeNode):
+         # One BlockStored per ``page_size`` chunk.
          if self.enable_kv_cache_events:
-             block_hash = hash(tuple(node.key))
-             parent_block_hash = hash(tuple(node.parent.key))
-             self.kv_event_queue.append(
-                 BlockStored(
-                     block_hashes=[block_hash],
-                     parent_block_hash=parent_block_hash,
-                     token_ids=node.key,
-                     block_size=len(node.key),
-                     lora_id=None,
+             # First chunk links to the last page of the parent node (if any).
+             if node.parent is None:
+                 parent_block_hash = None
+             else:
+                 last_page_start = (
+                     (len(node.parent.key) - 1) // self.page_size
+                 ) * self.page_size
+                 parent_parent_tokens = node.parent.key[last_page_start:]
+                 parent_block_hash = hash(tuple(parent_parent_tokens))
+
+             for start in range(0, len(node.key), self.page_size):
+                 page_tokens = node.key[start : start + self.page_size]
+                 if not page_tokens:
+                     continue
+
+                 block_hash = hash(tuple(page_tokens))
+
+                 self.kv_event_queue.append(
+                     BlockStored(
+                         block_hashes=[block_hash],
+                         parent_block_hash=parent_block_hash,
+                         token_ids=page_tokens,
+                         block_size=len(page_tokens),
+                         lora_id=None,
+                     )
                  )
-             )
+
+                 # Chain next chunk to this one.
+                 parent_block_hash = block_hash
 
      def _record_remove_event(self, node: TreeNode):
+         # One BlockRemoved per chunk.
          if self.enable_kv_cache_events:
-             block_hash = hash(tuple(node.key))
-             self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
+             for start in range(0, len(node.key), self.page_size):
+                 page_tokens = node.key[start : start + self.page_size]
+                 if not page_tokens:
+                     continue
+                 block_hash = hash(tuple(page_tokens))
+                 self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
 
      def _record_all_cleared_event(self):
          if self.enable_kv_cache_events:
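
The 0.4.7 code emitted one BlockStored/BlockRemoved per radix node, hashing the node's entire key; since a node can span many pages, 0.4.8 emits one event per page_size chunk and chains the hashes, with the first chunk linking to the parent node's last (possibly partial) page. A standalone sketch of that chaining (pure Python; the helper name paged_block_events is hypothetical):

    from typing import Iterator, List, Optional, Tuple

    def paged_block_events(
        node_key: List[int],
        parent_key: Optional[List[int]],
        page_size: int,
    ) -> Iterator[Tuple[int, Optional[int], List[int]]]:
        """Yield (block_hash, parent_block_hash, token_ids) per page of node_key."""
        if parent_key:
            # The first chunk chains to the parent's last, possibly partial, page.
            last_page_start = ((len(parent_key) - 1) // page_size) * page_size
            parent_hash: Optional[int] = hash(tuple(parent_key[last_page_start:]))
        else:
            parent_hash = None

        for start in range(0, len(node_key), page_size):
            page = node_key[start : start + page_size]
            block_hash = hash(tuple(page))
            yield block_hash, parent_hash, page
            parent_hash = block_hash  # chain the next page to this one

    # A 5-token node with page_size=2 yields 3 events (pages of 2, 2, and 1).
    events = list(paged_block_events([10, 11, 12, 13, 14], [1, 2, 3], page_size=2))
    assert len(events) == 3
    assert events[0][1] == hash((3,))    # parent's last partial page is [3]
    assert events[1][1] == events[0][0]  # second page chains to the first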