sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0

sglang/srt/managers/scheduler.py

@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -379,7 +379,7 @@ class Scheduler(
         # Init profiler
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
-        self.torch_profiler_activities: Optional[List[str]] = None
+        self.profiler_activities: Optional[List[str]] = None
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
@@ -1110,7 +1110,7 @@ class Scheduler(
         )
         if memory_leak:
             msg = (
-                "KV cache pool leak detected! "
+                "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
-                "Memory pool leak detected!"
+                "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
@@ -1186,7 +1186,7 @@ class Scheduler(
             ret = None

         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
             ret, _ = self.prepare_dp_attn_batch(ret)

         return ret
@@ -1282,7 +1282,7 @@ class Scheduler(
         ]

         if self.enable_hierarchical_cache:
-            self.tree_cache.read_to_load_cache()
+            self.tree_cache.ready_to_load_cache()

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
@@ -1703,18 +1703,12 @@ class Scheduler(
     def save_remote_model(self, params):
         url = params["url"]

-        if isinstance(self.tp_worker, TpModelWorkerClient):
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker

         worker.model_runner.save_remote_model(url)

     def save_sharded_model(self, params):
-        if isinstance(self.tp_worker, TpModelWorkerClient):
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker

         worker.model_runner.save_sharded_model(
             path=params["path"],
@@ -1813,7 +1807,11 @@ class Scheduler(
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
-                recv_req.output_dir, recv_req.num_steps, recv_req.activities
+                recv_req.output_dir,
+                recv_req.num_steps,
+                recv_req.activities,
+                recv_req.with_stack,
+                recv_req.record_shapes,
             )
         else:
             return self.stop_profile()
@@ -1823,8 +1821,10 @@ class Scheduler(
         output_dir: Optional[str],
         num_steps: Optional[int],
         activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
     ) -> None:
-        if self.torch_profiler_activities:
+        if self.profiler_activities:
             return ProfileReqOutput(
                 success=False,
                 message="Profiling is already in progress. Call /stop_profile first.",
@@ -1836,7 +1836,7 @@ class Scheduler(
             activities = ["CPU", "GPU"]

         self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_activities = activities
+        self.profiler_activities = activities
         logger.info(
             "Profiling starts. Traces will be saved to: %s",
             self.torch_profiler_output_dir,
@@ -1853,13 +1853,17 @@ class Scheduler(
         if torchprof_activities:
             self.torch_profiler = torch.profiler.profile(
                 activities=torchprof_activities,
-                with_stack=True,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
             )
             self.torch_profiler.start()

         if "MEM" in activities:
             torch.cuda.memory._record_memory_history(max_entries=100000)

+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+
         if num_steps:
             self.profiler_target_forward_ct = self.forward_ct + num_steps
             # The caller will be notified when reaching profiler_target_forward_ct
@@ -1868,7 +1872,7 @@ class Scheduler(
         return ProfileReqOutput(success=True, message="Succeeded")

     def stop_profile(self) -> None:
-        if self.torch_profiler_activities is None:
+        if self.profiler_activities is None:
             return

         logger.info("Stop profiling...")
@@ -1881,21 +1885,24 @@ class Scheduler(
                 )
             )

-        if "MEM" in self.torch_profiler_activities:
+        if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
-                self.torch_profiler_trace_dir,
+                self.torch_profiler_output_dir,
                 str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)

+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
         logger.info(
             "Profiling done. Traces are saved to: %s",
             self.torch_profiler_output_dir,
         )
         self.torch_profiler = None
         self.torch_profiler_output_dir = None
-        self.torch_profiler_activities = None
+        self.profiler_activities = None

         if self.profiler_target_forward_ct:
             self.send_to_tokenizer.send_pyobj(
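
The profiler changes above thread two new knobs (with_stack, record_shapes) through the request path and add a "CUDA_PROFILER" activity that brackets execution with CUDA's profiler start/stop hooks, so an external tool such as Nsight Systems can capture exactly that window. A minimal, self-contained sketch of the same pattern; the function and argument names here are illustrative, not sglang's API:

    import os
    import time
    from typing import Callable, List, Optional

    import torch

    def run_profiled(step_fn: Callable[[], None], num_steps: int,
                     activities: Optional[List[str]] = None,
                     output_dir: str = "/tmp/traces",
                     with_stack: Optional[bool] = None,
                     record_shapes: Optional[bool] = None) -> None:
        activities = activities or ["CPU", "GPU"]
        torchprof_activities = []
        if "CPU" in activities:
            torchprof_activities.append(torch.profiler.ProfilerActivity.CPU)
        if "GPU" in activities:
            torchprof_activities.append(torch.profiler.ProfilerActivity.CUDA)

        profiler = None
        if torchprof_activities:
            # Same defaulting as the diff: stacks on, shapes off, unless overridden.
            profiler = torch.profiler.profile(
                activities=torchprof_activities,
                with_stack=with_stack if with_stack is not None else True,
                record_shapes=record_shapes if record_shapes is not None else False,
            )
            profiler.start()
        if "MEM" in activities:
            torch.cuda.memory._record_memory_history(max_entries=100000)
        if "CUDA_PROFILER" in activities:
            # Marks the region for external capture (e.g. nsys --capture-range=cudaProfilerApi).
            torch.cuda.cudart().cudaProfilerStart()

        for _ in range(num_steps):
            step_fn()

        os.makedirs(output_dir, exist_ok=True)
        if profiler is not None:
            profiler.stop()
            profiler.export_chrome_trace(os.path.join(output_dir, f"{time.time()}.trace.json"))
        if "MEM" in activities:
            torch.cuda.memory._dump_snapshot(os.path.join(output_dir, f"{time.time()}-memory.pickle"))
            torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
        if "CUDA_PROFILER" in activities:
            torch.cuda.cudart().cudaProfilerStop()
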
@@ -1963,7 +1970,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
     # Generate the prefix
     if dp_rank is None:
         prefix = f" TP{tp_rank}"

sglang/srt/managers/tokenizer_manager.py

@@ -261,7 +261,6 @@ class TokenizerManager:
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -737,7 +736,7 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+        ), "dp_size must be 1 for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.

sglang/srt/managers/tp_worker.py

@@ -132,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        # A reference make this class has the same member as TpModelWorkerClient
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
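
The `self.worker = self` self-reference is what lets the Scheduler's save_remote_model/save_sharded_model (earlier in this diff) drop their isinstance branching: both TpModelWorker and TpModelWorkerClient now expose a `.worker` attribute that resolves to the underlying worker. A toy illustration of the pattern, with hypothetical stand-in classes rather than the real ones:

    class Worker:
        def __init__(self):
            self.worker = self  # same member the client wrapper exposes

        def save(self, path):
            print(f"saving to {path}")

    class WorkerClient:
        def __init__(self, worker):
            self.worker = worker  # proxy keeps a reference to the real worker

    def save_model(tp_worker, path):
        # uniform access; no isinstance() check needed
        tp_worker.worker.save(path)

    save_model(Worker(), "/tmp/a")                # plain worker
    save_model(WorkerClient(Worker()), "/tmp/b")  # wrapped worker
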

sglang/srt/managers/utils.py

@@ -1,11 +1,6 @@
-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple
-
-import torch
+from typing import Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req


sglang/srt/mem_cache/hiradix_cache.py

@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
-from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)

@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
         page_size: int,
         hicache_ratio: float,
     ):
-        if page_size != 1:
-            raise ValueError(
-                "Page size larger than 1 is not yet supported in HiRadixCache."
-            )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         else:
-            raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
+            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
-        self.page_size = page_size

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
+            page_size,
             load_cache_event=self.load_cache_event,
         )

@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = 1
         self.load_back_threshold = 10
         super().__init__(
-            req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
+            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
         )

     def reset(self):
@@ -210,9 +205,9 @@ class HiRadixCache(RadixCache):
             # only evict the host value of evicted nodes
             if not x.evicted:
                 continue
-            assert x.lock_ref == 0 and x.host_value is not None

-            assert self.cache_controller.evict_host(x.host_value) > 0
+            num_evicted += self.cache_controller.evict_host(x.host_value)
+
             for k, v in x.parent.children.items():
                 if v == x:
                     break
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

-    def read_to_load_cache(self):
+    def ready_to_load_cache(self):
         self.load_cache_event.set()

     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
-        if self.disable:
-            return [], self.root_node
+        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        if self.disable or len(key) == 0:
+            if include_evicted:
+                return empty_value, self.root_node, self.root_node
+            else:
+                return empty_value, self.root_node
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]

         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.cat(value)
         else:
-            value = torch.tensor([], dtype=torch.int64)
+            value = empty_value

         last_node_global = last_node
         while last_node.evicted:
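
The new page-size branch in match_prefix truncates the lookup key to a whole number of pages before walking the tree, so a trailing partial page can never produce a match that straddles a page boundary. The arithmetic in isolation:

    page_size = 4
    key = list(range(11))                                # 11 tokens
    page_aligned_len = len(key) // page_size * page_size
    assert page_aligned_len == 8                         # two full pages
    key = key[:page_aligned_len]                         # 3 trailing tokens dropped
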
@@ -323,11 +326,13 @@

     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
+        child_key = self.get_child_key_fn(key)
         value = []
-        while len(key) > 0 and key[0] in node.children.keys():
-            child = node.children[key[0]]
+
+        while len(key) > 0 and child_key in node.children.keys():
+            child = node.children[child_key]
             child.last_access_time = time.time()
-            prefix_len = _key_match(child.key, key)
+            prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 if not new_node.evicted:
@@ -339,12 +344,16 @@
                 value.append(child.value)
             node = child
             key = key[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
         return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len]: child}
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
         child.key = child.key[split_len:]
-        new_node.parent.children[key[0]] = new_node
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node
        return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
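
These helpers replace the hard-coded first-token child lookup (`key[0]`) and the page-size-1 `_key_match` with the `get_child_key_fn`/`key_match_fn` members inherited from RadixCache, which are selected according to the page size. A sketch of how such a pair can be defined; this mirrors the idea, and the exact upstream definitions may differ:

    from typing import List

    def make_key_fns(page_size: int):
        if page_size == 1:
            def get_child_key(key: List[int]):
                return key[0]                       # a single token indexes a child

            def key_match(k0: List[int], k1: List[int]) -> int:
                i = 0
                for a, b in zip(k0, k1):
                    if a != b:
                        break
                    i += 1
                return i
        else:
            def get_child_key(key: List[int]):
                return tuple(key[:page_size])       # the first page indexes a child

            def key_match(k0: List[int], k1: List[int]) -> int:
                # compare page by page so a match never splits a page
                i = 0
                while (i + page_size <= min(len(k0), len(k1))
                       and k0[i:i + page_size] == k1[i:i + page_size]):
                    i += page_size
                return i
        return get_child_key, key_match

    get_child_key_fn, key_match_fn = make_key_fns(page_size=4)
    assert get_child_key_fn([7, 8, 9, 10, 11]) == (7, 8, 9, 10)
    assert key_match_fn([1, 2, 3, 4, 9, 9], [1, 2, 3, 4, 5, 6]) == 4
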
@@ -369,52 +378,53 @@
         if len(key) == 0:
             return 0

-        if key[0] in node.children.keys():
-            child = node.children[key[0]]
-            prefix_len = _key_match(child.key, key)
+        child_key = self.get_child_key_fn(key)
+        total_prefix_length = 0

-            if prefix_len == len(child.key):
-                if child.evicted:
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.time()
+            prefix_len = self.key_match_fn(node.key, key)
+
+            if prefix_len == len(node.key):
+                if node.evicted:
                     # change the reference if the node is evicted
                     # this often happens in the case of KV cache recomputation
-                    child.value = value[:prefix_len]
-                    self.token_to_kv_pool_host.update_synced(child.host_value)
-                    self.evictable_size_ += len(value[:prefix_len])
-                    return self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
+                    node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(node.host_value)
+                    self.evictable_size_ += len(node.value)
                 else:
-                    self.inc_hit_count(child)
-                    return prefix_len + self._insert_helper(
-                        child, key[prefix_len:], value[prefix_len:]
-                    )
-
-            # partial match, split the node
-            new_node = self._split_node(child.key, child, prefix_len)
-            if new_node.evicted:
-                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
-                self.evictable_size_ += len(new_node.value)
-                return self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                    self.inc_hit_count(node)
+                total_prefix_length += prefix_len
             else:
-                self.inc_hit_count(new_node)
-                return prefix_len + self._insert_helper(
-                    new_node, key[prefix_len:], value[prefix_len:]
-                )
+                # partial match, split the node
+                new_node = self._split_node(node.key, node, prefix_len)
+                if new_node.evicted:
+                    new_node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                    self.evictable_size_ += len(new_node.value)
+                else:
+                    self.inc_hit_count(new_node)
+                total_prefix_length += prefix_len
+                node = new_node
+
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
             new_node.key = key
             new_node.value = value
-            node.children[key[0]] = new_node
+            node.children[child_key] = new_node
             self.evictable_size_ += len(value)

             if self.cache_controller.write_policy == "write_through":
                 self.write_backup(new_node)
-        return 0
+        return total_prefix_length

     def _collect_leaves_device(self):
         def is_leaf(node):
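
The rewrite above turns a recursive insert into a single descent loop: the current node advances through matching children, `key`/`value` shrink by `prefix_len` each round, and the accumulated `total_prefix_length` is returned instead of a constant 0. The control-flow skeleton on a toy single-token trie; `Node`, `match_len`, and `insert` are hypothetical stand-ins for TreeNode/_insert_helper, and node splitting on partial matches is elided:

    class Node:
        def __init__(self, key):
            self.key = key        # token span stored on the edge into this node
            self.children = {}    # first token of a child's key -> child

    def match_len(a, b):
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    def insert(root, key):
        node, total_prefix_length = root, 0
        while key and key[0] in node.children:
            child = node.children[key[0]]
            prefix_len = match_len(child.key, key)
            if prefix_len < len(child.key):
                break  # the real code splits `child` here and keeps descending
            node = child
            total_prefix_length += prefix_len
            key = key[prefix_len:]
        if key:
            node.children[key[0]] = Node(key)
        return total_prefix_length

    root = Node([])
    assert insert(root, [1, 2, 3]) == 0     # nothing matched; new leaf created
    assert insert(root, [1, 2, 3, 4]) == 3  # existing prefix [1, 2, 3] reported
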

sglang/srt/mem_cache/memory_pool.py

@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_slots
+
+    def restore_state(self, free_slots):
+        self.free_slots = free_slots
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
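
backup_state/restore_state give callers a snapshot/rollback hook over the allocator's free list (the paged allocator later in this diff gains the same pair for free_pages). Plausibly this supports tentative allocations such as speculative draft passes, though the call site is not shown in this diff. A toy model of the contract:

    class TinyAllocator:
        # minimal stand-in mirroring the backup/restore contract above
        def __init__(self, size):
            self.free_slots = list(range(size))

        def alloc(self, n):
            out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
            return out

        def backup_state(self):
            return self.free_slots

        def restore_state(self, free_slots):
            self.free_slots = free_slots

    a = TinyAllocator(8)
    saved = a.backup_state()
    a.alloc(5)                      # tentative allocation (e.g. a draft pass)
    assert len(a.free_slots) == 3
    a.restore_state(saved)          # roll back to the snapshot
    assert len(a.free_slots) == 8

This works because alloc rebinds free_slots to a slice instead of mutating it in place, so the reference returned by backup_state() still sees the pre-allocation list; the tensor-based allocators in the diff use the same rebinding style.
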
@@ -602,8 +608,9 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
-        device: str = "cpu",
+        pin_memory: bool,
+        device: str,
+        page_size: int,
     ):
         assert (
             host_to_device_ratio >= 1
@@ -614,8 +621,11 @@ class HostKVCache(abc.ABC):
         self.host_to_device_ratio = host_to_device_ratio
         self.pin_memory = pin_memory
         self.device = device
+        self.page_size = page_size

         self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
         self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()

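Rounding the host pool size down to a multiple of the page size keeps every host slot addressable as part of a whole page. A worked example of the two lines above:

    device_pool_size = 1000
    host_to_device_ratio = 2.5
    page_size = 64

    size = int(device_pool_size * host_to_device_ratio)  # 2500 token slots
    size = size - (size % page_size)                     # 2496 = 39 full pages
    assert size % page_size == 0
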
@@ -769,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.head_num = self.device_pool.head_num
@@ -805,16 +818,48 @@
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+

 class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.kv_lora_rank = self.device_pool.kv_lora_rank
@@ -851,3 +896,24 @@

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
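
Both host pools copy whole pages: `device_indices` carries one slot per token, so striding by `page_size` picks out each page's first slot, and each iteration moves one contiguous `page_size`-long span. The indexing in isolation:

    import torch

    page_size = 4
    # two pages worth of token slots, pages starting at 12 and 28
    device_indices = torch.tensor([12, 13, 14, 15, 28, 29, 30, 31])
    host_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])

    page_starts = device_indices[::page_size]          # tensor([12, 28])
    for i, d_index in enumerate(page_starts.tolist()):
        h_index = host_indices[i * page_size].item()   # 0, then 4
        # real code: host_buf[h_index:h_index+page_size].copy_(dev_buf[d_index:d_index+page_size])
        assert (device_indices[i * page_size : (i + 1) * page_size]
                == torch.arange(d_index, d_index + page_size)).all()
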

sglang/srt/mem_cache/paged_allocator.py

@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_pages) * self.page_size

+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        # page-aligned allocation, returning contiguous indices of pages
+        if self.debug_mode:
+            assert (
+                need_size % self.page_size == 0
+            ), "The allocation size should be page-aligned"
+
+        num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            return None
+
+        out_pages = self.free_pages[:num_pages]
+        self.free_pages = self.free_pages[num_pages:]
+
+        out_indices = (
+            out_pages[:, None] * self.page_size
+            + torch.arange(self.page_size, device=self.device)
+        ).reshape(-1)
+
+        return out_indices
+
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
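
The new `alloc` hands out whole pages and expands the page ids into per-token slot indices with a broadcasted outer sum. The expansion on its own:

    import torch

    page_size = 4
    free_pages = torch.tensor([3, 7])        # page ids handed out
    out_indices = (
        free_pages[:, None] * page_size      # page base offsets: [[12], [28]]
        + torch.arange(page_size)            # per-page offsets: [0, 1, 2, 3]
    ).reshape(-1)
    assert out_indices.tolist() == [12, 13, 14, 15, 28, 29, 30, 31]
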
@@ -218,6 +242,9 @@
             next_power_of_2(extend_num_tokens),
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         merged_value = self.ret_values.item()
         num_new_pages = merged_value >> 32
         if num_new_pages > len(self.free_pages):
@@ -248,6 +275,9 @@
             self.page_size,
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         num_new_pages = self.ret_values.item()
         if num_new_pages > len(self.free_pages):
             return None
@@ -265,6 +295,9 @@
         else:
             self.free_group.append(free_index)

+        if self.debug_mode:
+            assert len(torch.unique(self.free_pages)) == len(self.free_pages)
+
     def free_group_begin(self):
         self.is_not_in_free_group = False
         self.free_group = []
@@ -274,6 +307,12 @@
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_pages
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(