sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py

@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
 import torch
 
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -183,8 +183,11 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
 
         pp_proxy_tensors = None
@@ -196,11 +199,11 @@ class TpModelWorker:
            )
 
        if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
            )
-            if model_worker_batch.launch_done is not None:
-                model_worker_batch.launch_done.set()
+            if launch_done is not None:
+                launch_done.set()
 
            if skip_sample:
                next_token_ids = None
@@ -209,17 +212,17 @@ class TpModelWorker:
                    logits_output, model_worker_batch
                )
 
-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
        else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                forward_batch,
                pp_proxy_tensors=pp_proxy_tensors,
            )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
 
    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
        embeddings = logits_output.embeddings
        return embeddings
 
sglang/srt/managers/tp_worker_overlap_thread.py

@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple
 
 import psutil
 import torch
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
        batch_lists = [None] * 2
 
        while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
            if not model_worker_batch:
                break
 
+            sync_event.wait()
+
            # Keep a reference of model_worker_batch by storing it into a list.
            # Otherwise, the tensor members of model_worker_batch will be released
            # by pytorch and cause CUDA illegal memory access errors.
@@ -145,8 +147,10 @@
            resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
            # Run forward
-            logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
            )
 
            # Update the future token ids map
@@ -171,14 +175,18 @@
            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
            copy_done.record()
 
-            self.output_queue.put((copy_done, logits_output, next_token_ids))
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )
 
    def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
        """
        This function is called to resolve the last batch result and
        wait for the current batch to be launched. Used in overlap mode.
        """
-        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )
 
        if launch_done is not None:
            launch_done.wait()
@@ -193,9 +201,11 @@
                logits_output.input_token_logprobs.tolist()
            )
        next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph
 
-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
        # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
        sampling_info = model_worker_batch.sampling_info
        sampling_info.update_penalties()
@@ -206,10 +216,11 @@
        )
 
        # A cuda stream sync here to avoid the cuda illegal memory access error.
-        self.scheduler_stream.synchronize()
+        sync_event = torch.get_device_module(self.device).Event()
+        sync_event.record(self.scheduler_stream)
 
        # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
 
        # Allocate output future objects
        bs = len(model_worker_batch.seq_lens)
@@ -223,7 +234,7 @@
        self.future_token_ids_ct = (
            self.future_token_ids_ct + bs
        ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False
 
    def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
        success, message = self.worker.update_weights_from_disk(recv_req)
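
Note: the last two hunks above replace the blocking self.scheduler_stream.synchronize() with a CUDA event that is recorded by the scheduler and waited on by the forward thread before it touches the batch. A minimal, self-contained sketch of that handoff pattern (assuming a CUDA device; the variable names below are illustrative, not sglang APIs):

    import torch

    scheduler_stream = torch.cuda.current_stream()

    # Scheduler side: instead of blocking the host with scheduler_stream.synchronize(),
    # record an event marking "everything enqueued so far" and pass it along with the batch.
    sync_event = torch.cuda.Event()
    sync_event.record(scheduler_stream)

    # Forward-thread side: make the consumer's current stream wait for the event.
    # Event.wait() inserts a device-side dependency; the host thread keeps running.
    sync_event.wait()

Compared with synchronize(), this keeps the scheduler thread free to prepare the next batch while the forward thread orders itself after the recorded point.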
sglang/srt/mem_cache/base_prefix_cache.py

@@ -48,3 +48,6 @@ class BasePrefixCache(ABC):
 
    def pretty_print(self):
        raise NotImplementedError()
+
+    def take_events(self):
+        return []

sglang/srt/mem_cache/chunk_cache.py

@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
 
    def cache_finished_req(self, req: Req):
        kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+            req.req_pool_idx,
+            # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
+            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
        ]
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)
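
Note: the max(..., 0) above guards the decode-server case where a request finishes before generating any tokens. A quick standalone check of the slice length (plain Python, illustrative values only):

    origin_input_ids = [101, 102, 103]  # prompt tokens
    output_ids = []                     # decode server: nothing generated yet

    old_len = len(origin_input_ids) + len(output_ids) - 1          # 2: one KV slot is never freed
    new_len = len(origin_input_ids) + max(len(output_ids) - 1, 0)  # 3: the whole prompt is freed
    print(old_len, new_len)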
sglang/srt/mem_cache/hiradix_cache.py

@@ -335,13 +335,13 @@ class HiRadixCache(RadixCache):
        return value, last_node
 
    def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
        child_key = self.get_child_key_fn(key)
        value = []
 
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
-            child.last_access_time = time.time()
+            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
@@ -386,7 +386,7 @@
        return new_node
 
    def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0
 
@@ -395,7 +395,7 @@
 
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
-            node.last_access_time = time.time()
+            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
 
            if prefix_len == len(node.key):
sglang/srt/mem_cache/memory_pool.py

@@ -38,11 +38,17 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_compiler_backend
+from sglang.srt.utils import (
+    debug_timing,
+    get_compiler_backend,
+    is_cuda,
+    next_power_of_2,
+)
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
+_is_cuda = is_cuda()
 
 
 class ReqToTokenPool:
@@ -94,6 +100,33 @@ class ReqToTokenPool:
 
 
 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
 
    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -217,30 +250,24 @@ class MHATokenToKVPool(KVCache):
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
        )
 
        self.head_num = head_num
        self.head_dim = head_dim
-        self.layer_num = layer_num
        self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
 
        self.layer_transfer_counter = None
-        self.capture_mode = False
        self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream()
+        self.alt_stream = self.device_module.Stream() if is_cuda else None
 
        k_size, v_size = self.get_kv_size_bytes()
        logger.info(
@@ -357,6 +384,8 @@
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
    ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
@@ -370,7 +399,7 @@
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and cache_k.shape[0] < 4:
+        if get_is_capture_mode() and self.alt_stream is not None:
            # Overlap the copy of K and V cache for small batch size
            current_stream = self.device_module.current_stream()
            self.alt_stream.wait_stream(current_stream)
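
Note: the branch above overlaps the K and V writes on an alternate stream during CUDA graph capture; the per-pool capture_mode flag is replaced by the global get_is_capture_mode(), and the path is skipped when no alternate stream exists (e.g. on non-CUDA devices). A hedged, self-contained sketch of the overlap pattern it guards (assuming CUDA; buffer sizes are arbitrary):

    import torch

    alt_stream = torch.cuda.Stream()
    k_buf = torch.empty(1024, device="cuda")
    v_buf = torch.empty(1024, device="cuda")
    cache_k = torch.randn(1024, device="cuda")
    cache_v = torch.randn(1024, device="cuda")

    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)      # alt stream sees cache_k/cache_v as ready
    with torch.cuda.stream(alt_stream):
        k_buf.copy_(cache_k)                    # K copy runs on the alternate stream
    v_buf.copy_(cache_v)                        # V copy stays on the current stream
    current_stream.wait_stream(alt_stream)      # rejoin before anything reads k_buf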
@@ -493,26 +522,21 @@ class MLATokenToKVPool(KVCache):
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
+
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
@@ -524,7 +548,6 @@
        ]
 
        self.layer_transfer_counter = None
-        self.page_size = page_size
 
        kv_size = self.get_kv_size_bytes()
        logger.info(
@@ -637,20 +660,18 @@ class DoubleSparseTokenToKVPool(KVCache):
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
        )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
            # [size, head_num, head_dim] for each layer
            self.k_buffer = [
                torch.zeros(
@@ -673,9 +694,6 @@
            for _ in range(layer_num)
        ]
 
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id - self.start_layer]
 
@@ -743,7 +761,7 @@ class HostKVCache(abc.ABC):
 
    def __init__(
        self,
-        device_pool: MHATokenToKVPool,
+        device_pool: KVCache,
        host_to_device_ratio: float,
        host_size: int,
        pin_memory: bool,
@@ -762,6 +780,8 @@
        self.size = int(device_pool.size * host_to_device_ratio)
        # Align the host memory pool size to the page size
        self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer
 
        assert (
            self.size > device_pool.size
@@ -913,6 +933,8 @@
 
 
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
    def __init__(
        self,
        device_pool: MHATokenToKVPool,
@@ -996,6 +1018,8 @@
 
 
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
    def __init__(
        self,
        device_pool: MLATokenToKVPool,
sglang/srt/mem_cache/multimodal_cache.py (new file)

@@ -0,0 +1,45 @@
+from typing import Dict
+
+import torch
+
+
+class MultiModalCache:
+    """MultiModalCache is used to store vlm encoder results"""
+
+    def __init__(
+        self,
+        max_size: int,
+    ):
+        self.max_size = max_size
+        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.current_size = 0
+
+    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
+        if mm_hash in self.mm_cache:
+            return True
+        data_size = self._get_tensor_size(embedding)
+        if self.current_size + data_size > self.max_size:
+            return False
+        self.mm_cache[mm_hash] = embedding
+        self.current_size += data_size
+        return True
+
+    def get(self, mm_hash: int) -> torch.Tensor:
+        return self.mm_cache.get(mm_hash)
+
+    def free(self, mm_hash: int) -> bool:
+        if mm_hash not in self.mm_cache:
+            return False
+        old_embedding = self.mm_cache.pop(mm_hash)
+        self.current_size -= self._get_tensor_size(old_embedding)
+        return True
+
+    def clear(self):
+        self.mm_cache.clear()
+        self.current_size = 0
+
+    def _get_tensor_size(self, embedding: torch.Tensor):
+        return embedding.element_size() * embedding.numel()
+
+    def __len__(self):
+        return len(self.mm_cache)
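
Note: a hedged usage sketch of the new cache (the import path follows the file listing above; the byte budget and hash value are illustrative):

    import torch

    from sglang.srt.mem_cache.multimodal_cache import MultiModalCache

    cache = MultiModalCache(max_size=64 * 1024 * 1024)  # byte budget for encoder outputs

    mm_hash = 12345                        # hash of the image/audio item
    embedding = torch.zeros(16, 4096)
    if not cache.put(mm_hash, embedding):  # False: storing it would exceed max_size
        pass                               # caller falls back to recomputing the encoder pass
    hit = cache.get(mm_hash)               # None if never stored (or already freed)
    cache.free(mm_hash)                    # returns the bytes to the budget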
sglang/srt/mem_cache/radix_cache.py

@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.disaggregation.kv_events import (
+    AllBlocksCleared,
+    BlockRemoved,
+    BlockStored,
+    KVCacheEvent,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -45,7 +51,7 @@ class TreeNode:
        self.key = None
        self.value = None
        self.lock_ref = 0
-        self.last_access_time = time.time()
+        self.last_access_time = time.monotonic()
 
        self.hit_count = 0
        # indicating the node is loading KV cache from host
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
        page_size: int,
        disable: bool = False,
+        enable_kv_cache_events: bool = False,
    ):
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = page_size
        self.disable = disable
+        self.enable_kv_cache_events = enable_kv_cache_events
+        self.kv_event_queue = []
 
        if self.token_to_kv_pool_allocator:
            self.device = self.token_to_kv_pool_allocator.device
@@ -124,6 +133,7 @@
        self.root_node.lock_ref = 1
        self.evictable_size_ = 0
        self.protected_size_ = 0
+        self._record_all_cleared_event()
 
    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
        """Find the matching prefix from the radix tree.
@@ -273,6 +283,8 @@
            if len(x.parent.children) == 0:
                heapq.heappush(leaves, x.parent)
 
+            self._record_remove_event(x)
+
    def inc_lock_ref(self, node: TreeNode):
        if self.disable:
            return 0
@@ -322,14 +334,14 @@
    ##### Internal Helper Functions #####
 
    def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
 
        child_key = self.get_child_key_fn(key)
 
        value = []
        while len(key) > 0 and child_key in node.children.keys():
            child = node.children[child_key]
-            child.last_access_time = time.time()
+            child.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(child.key, key)
            if prefix_len < len(child.key):
                new_node = self._split_node(child.key, child, prefix_len)
@@ -348,6 +360,7 @@
 
    def _split_node(self, key, child: TreeNode, split_len: int):
        # new_node -> child
+        self._record_remove_event(child)
        new_node = TreeNode()
        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
        new_node.parent = child.parent
@@ -358,10 +371,14 @@
        child.key = child.key[split_len:]
        child.value = child.value[split_len:]
        new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+        self._record_store_event(new_node)
+        self._record_store_event(child)
+
        return new_node
 
    def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.time()
+        node.last_access_time = time.monotonic()
        if len(key) == 0:
            return 0
 
@@ -370,7 +387,7 @@
        total_prefix_length = 0
        while len(key) > 0 and child_key in node.children.keys():
            node = node.children[child_key]
-            node.last_access_time = time.time()
+            node.last_access_time = time.monotonic()
            prefix_len = self.key_match_fn(node.key, key)
            total_prefix_length += prefix_len
            key = key[prefix_len:]
@@ -390,6 +407,7 @@
            new_node.value = value
            node.children[child_key] = new_node
            self.evictable_size_ += len(value)
+            self._record_store_event(new_node)
        return total_prefix_length
 
    def _print_helper(self, node: TreeNode, indent: int):
@@ -442,6 +460,41 @@
 
        return ret_list
 
+    def _record_store_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            parent_block_hash = hash(tuple(node.parent.key))
+            self.kv_event_queue.append(
+                BlockStored(
+                    block_hashes=[block_hash],
+                    parent_block_hash=parent_block_hash,
+                    token_ids=node.key,
+                    block_size=len(node.key),
+                    lora_id=None,
+                )
+            )
+
+    def _record_remove_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
+
+    def _record_all_cleared_event(self):
+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(AllBlocksCleared())
+
+    def take_events(self):
+        """Atomically takes all events and clears the queue.
+
+        Returns:
+            A list of KV cache events.
+        """
+        if not self.enable_kv_cache_events:
+            return []
+        events = self.kv_event_queue
+        self.kv_event_queue = []
+        return events
+
 
 if __name__ == "__main__":
    tree = RadixCache(None, None, page_size=1, disable=False)
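
Note: together with the take_events() stub added to BasePrefixCache earlier in this diff, the new _record_* helpers let a scheduler drain KV cache events without touching tree internals. A hedged sketch, reusing the constructor call from the __main__ block above and assuming insert() behaves as in previous releases:

    from sglang.srt.mem_cache.radix_cache import RadixCache

    tree = RadixCache(None, None, page_size=1, disable=False, enable_kv_cache_events=True)
    tree.insert([1, 2, 3])           # records a BlockStored event for the new node
    events = tree.take_events()      # atomically drains and clears the queue
    for event in events:
        print(type(event).__name__)  # e.g. AllBlocksCleared, BlockStored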