sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -29,6 +29,7 @@ import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
+from enum import Enum
 from http import HTTPStatus
 from typing import (
     Any,
@@ -70,7 +71,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -116,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -202,13 +203,29 @@ class TokenizerManager:
 
         if self.model_config.is_multimodal:
             import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
             transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
@@ -225,10 +242,10 @@ class TokenizerManager:
             self.tokenizer = get_tokenizer_from_processor(self.processor)
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
         else:
-            self.mm_processor = None
+            self.mm_processor = self.processor = None
 
             if server_args.skip_tokenizer_init:
-                self.tokenizer = self.processor = None
+                self.tokenizer = None
             else:
                 self.tokenizer = get_tokenizer(
                     server_args.tokenizer_path,
@@ -255,6 +272,7 @@ class TokenizerManager:
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -1069,38 +1087,56 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-        async with self.lora_update_lock:
-            # Generate new uniquely identifiable LoRARef object.
-            new_adapter = LoRARef(
-                lora_name=obj.lora_name,
-                lora_path=obj.lora_path,
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
             )
 
-            # Trigger the actual loading operation at the backend processes.
-            obj.lora_id = new_adapter.lora_id
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )
 
-            # Register the LoRA adapter only after loading is successful.
-            if result.success:
-                await self.lora_registry.register(new_adapter)
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                )
 
-            return result
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )
 
     async def unload_lora_adapter(
         self,
@@ -1108,37 +1144,41 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-        assert (
-            obj.lora_name is not None
-        ), "lora_name must be provided to unload LoRA adapter"
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start unload Lora adapter. Lora name=%s",
-            obj.lora_name,
-        )
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )
 
-        async with self.lora_update_lock:
-            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-            # from being started.
-            lora_id = await self.lora_registry.unregister(obj.lora_name)
-            obj.lora_id = lora_id
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id
 
-            # Initiate the actual unloading operation at the backend processes only after all
-            # ongoing requests using this LoRA adapter are finished.
-            await self.lora_registry.wait_for_unload(lora_id)
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]
 
-            return result
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1767,6 +1807,8 @@ class TokenizerManager:
         asyncio.create_task(asyncio.to_thread(background_task))
 
     def _handle_abort_req(self, recv_obj):
+        if is_health_check_generate_req(recv_obj):
+            return
        state = self.rid_to_state[recv_obj.rid]
        state.finished = True
        if recv_obj.finished_reason:
@@ -1901,6 +1943,16 @@ class TokenizerManager:
         return scores
 
 
+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
     is_cross_node = server_args.dist_init_addr
--- a/sglang/srt/managers/utils.py
+++ b/sglang/srt/managers/utils.py
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
 
@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg
 
     return None
+
+
+class DPBalanceMeta:
+    """
+    This class will be use in scheduler and dp controller
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # we must destructor this class manually
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
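
Note: DPBalanceMeta keeps per-worker token counts and on-the-fly request info in a multiprocessing.Manager so the data-parallel controller and the schedulers can share them across processes (the __getstate__/__setstate__ pair drops the Manager when the object is pickled to a child). Below is a minimal standalone sketch of that pattern; the least-loaded selection policy is an assumption for illustration, not taken from the diff.

import multiprocessing as mp
from typing import List


def pick_worker_and_account(local_tokens, mutex, new_request_tokens: int) -> int:
    # Under the shared lock, choose the least-loaded worker and charge the new request to it.
    with mutex:
        tokens: List[int] = list(local_tokens)
        target = tokens.index(min(tokens))
        local_tokens[target] = tokens[target] + new_request_tokens
        return target


if __name__ == "__main__":
    manager = mp.Manager()
    mutex = manager.Lock()
    local_tokens = manager.list([0, 0])  # one slot per dp worker, as in DPBalanceMeta
    assert pick_worker_and_account(local_tokens, mutex, 128) == 0
    assert pick_worker_and_account(local_tokens, mutex, 64) == 1
    print(list(local_tokens))            # [128, 64]
    manager.shutdown()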
--- /dev/null
+++ b/sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+from torch.utils.cpp_extension import load
+
+_abs_path = os.path.dirname(os.path.abspath(__file__))
+radix_tree_cpp = load(
+    name="radix_tree_cpp",
+    sources=[
+        f"{_abs_path}/tree_v2_binding.cpp",
+        f"{_abs_path}/tree_v2_debug.cpp",
+        f"{_abs_path}/tree_v2.cpp",
+    ],
+    extra_cflags=["-O3", "-std=c++20"],
+)
+
+if TYPE_CHECKING:
+
+    class TreeNodeCpp:
+        """
+        A placeholder for the TreeNode class. Cannot be constructed elsewhere.
+        """
+
+    class IOHandle:
+        """
+        A placeholder for the IOHandle class. Cannot be constructed elsewhere.
+        """
+
+    class RadixTreeCpp:
+        def __init__(
+            self,
+            disabled: bool,
+            host_size: Optional[int],
+            page_size: int,
+            write_through_threshold: int,
+        ):
+            """
+            Initializes the RadixTreeCpp instance.
+            Args:
+                disabled (bool): If True, the radix tree is disabled.
+                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
+                page_size (int): Size of the page for the radix tree.
+                write_through_threshold (int): Threshold for writing through from GPU to CPU.
+            """
+            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
+                disabled, host_size, page_size, write_through_threshold
+            )
+
+        def match_prefix(
+            self, prefix: List[int]
+        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+            """
+            Matches a prefix in the radix tree.
+            Args:
+                prefix (List[int]): The prefix to match.
+            Returns:
+                Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
+                    0. A list of indices that is matched by the prefix on the GPU.
+                    1. Sum length of the indices matched on the CPU.
+                    2. The last node of the prefix matched on the GPU.
+                    3. The last node of the prefix matched on the CPU.
+            """
+            return self.tree.match_prefix(prefix)
+
+        def evict(self, num_tokens: int) -> List[torch.Tensor]:
+            """
+            Evicts a number of tokens from the radix tree.
+            Args:
+                num_tokens (int): The number of tokens to evict.
+            Returns:
+                List[torch.Tensor]: A list of indices that were evicted.
+            """
+            return self.tree.evict(num_tokens)
+
+        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
+            """
+            Locks or unlocks a reference to a tree node.
+            After locking, the node will not be evicted from the radix tree.
+            Args:
+                handle (TreeNodeCpp): The tree node to lock or unlock.
+                lock (bool): If True, locks the node; if False, unlocks it.
+            """
+            return self.tree.lock_ref(handle, lock)
+
+        def writing_through(
+            self, key: List[int], indices: torch.Tensor
+        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+            """
+            Inserts a key-value pair into the radix tree and perform write-through check.
+            Args:
+                key (List[int]): The key to insert.
+                indices (torch.Tensor): The value associated with the key.
+            Returns:
+                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+                    0. A list of (IOHandle, device indices, host indices) tuples.
+                       These IOhandles require write-through to the CPU in python side.
+                    1. The number of indices that are matched on device.
+            """
+            return self.tree.writing_through(key, indices)
+
+        def loading_onboard(
+            self,
+            host_node: TreeNodeCpp,
+            new_device_indices: torch.Tensor,
+        ) -> Tuple[IOHandle, List[torch.Tensor]]:
+            """
+            Updates the device indices of tree nodes within a range on the tree.
+            Args:
+                host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
+                new_device_indices (torch.Tensor): The new device indices to set.
+                    The length of this tensor must be exactly host indices length.
+            Returns:
+                Tuple[IOHandle, List[torch.Tensor]]:
+                    0. An IOHandle that requires loading to the CPU in python side.
+                    1. A list of host indices corresponding to the new device indices.
+            """
+            return self.tree.loading_onboard(host_node, new_device_indices)
+
+        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the write-through process for a tree node.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the write-through; if False, just indicates failure.
+            """
+            return self.tree.commit_writing_through(handle, success)
+
+        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the load onboard process for tree nodes within a range on the tree.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the load-onboard; if False, just indicates failure.
+            """
+            return self.tree.commit_loading_onboard(handle, success)
+
+        def evictable_size(self) -> int:
+            """
+            Returns the size of the evictable part of the radix tree.
+            This is the size of the part that can be evicted from the GPU (ref_count = 0).
+            Returns:
+                int: The size of the evictable part.
+            """
+            return self.tree.evictable_size()
+
+        def protected_size(self) -> int:
+            """
+            Returns the size of the protected part of the radix tree.
+            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
+            Returns:
+                int: The size of the protected part.
+            """
+            return self.tree.protected_size()
+
+        def total_size(self) -> int:
+            """
+            Returns the total size of the radix tree (including CPU nodes).
+            Returns:
+                int: The total size of the radix tree.
+            """
+            return self.tree.total_size()
+
+        def reset(self) -> None:
+            """
+            Resets the radix tree, clearing all nodes and indices.
+            """
+            return self.tree.reset()
+
+        def debug_print(self) -> None:
+            """
+            Prints the internal state of the radix tree for debugging purposes.
+            """
+            return self.tree.debug_print()
+
+else:
+    # Real implementation of the classes for runtime
+    RadixTreeCpp = radix_tree_cpp.RadixTree
+    TreeNodeCpp = object
+    IOHandle = object
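
Note: the new wrapper only declares the binding; the write-through and onboarding methods form a two-phase protocol (request an IOHandle, perform the copy on the Python side, then commit). The sketch below is inferred from the docstrings alone; it assumes the C++ extension builds on the target machine, and copy_fn stands in for whatever transfer routine the cache controller actually uses.

import torch


def write_through(tree, key, device_indices: torch.Tensor, copy_fn) -> int:
    # Phase 1: insert the key; the tree returns IOHandles for nodes that need host backing.
    pending, num_device_hit = tree.writing_through(key, device_indices)
    for io_handle, dev_idx, host_idx in pending:
        try:
            copy_fn(dev_idx, host_idx)                     # e.g. a device-to-host copy
            tree.commit_writing_through(io_handle, True)   # Phase 2: acknowledge success
        except Exception:
            tree.commit_writing_through(io_handle, False)  # or report failure
    return num_device_hit


def lookup(tree, prefix):
    # match_prefix returns GPU indices, the CPU hit length, and the deepest matched nodes.
    device_indices, host_hit_len, device_node, host_node = tree.match_prefix(prefix)
    tree.lock_ref(device_node, True)  # pin the matched path while the request uses it
    return device_indices, host_hit_len, device_node, host_node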
--- a/sglang/srt/mem_cache/hicache_storage.py
+++ b/sglang/srt/mem_cache/hicache_storage.py
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """
 
-    # todo, translate tensor object access for different TP ranks
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool
 
     @abstractmethod
@@ -117,26 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location: Optional[Any] = None,
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # todo: fixing the target_location logic to enable in-place loading
-            loaded_tensor = torch.load(tensor_path)
-            if isinstance(loaded_tensor, torch.Tensor):
-                return loaded_tensor
-            else:
-                logger.error(f"Loaded data for key {key} is not a tensor.")
-                return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
+            logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None
 
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
@@ -159,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.save(value, tensor_path)
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")
--- a/sglang/srt/mem_cache/hiradix_cache.py
+++ b/sglang/srt/mem_cache/hiradix_cache.py
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
     ):
+
+        if hicache_io_backend == "direct":
+            if hicache_mem_layout == "page_first":
+                hicache_mem_layout = "layer_first"
+                logger.warning(
+                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
+                )
+
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -436,7 +453,7 @@ class HiRadixCache(RadixCache):
             last_host_node,
             fetched_token_ids,
             written_indices,
-            hash_value[:min_completed_tokens],
+            hash_value[: min_completed_tokens // self.page_size],
         )
         if len(written_indices):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
@@ -529,7 +546,7 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(node.key, key)
             key = key[prefix_len:]
             host_value = host_value[prefix_len:]
-            hash_value = hash_value[prefix_len:]
+            hash_value = hash_value[prefix_len // self.page_size :]
             matched_length += prefix_len
 
             if prefix_len < len(node.key):
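
Note: the last two hunks change token-indexed slices into page-indexed ones, because hash_value holds one hash per page while prefix_len and min_completed_tokens count tokens. A tiny worked example of the new arithmetic (numbers are made up):

page_size = 64
prefix_len = 128                       # tokens already matched in the tree
hash_value = ["h0", "h1", "h2", "h3"]  # one hash per page of the host value

# Old slicing (hash_value[prefix_len:]) would skip 128 entries and return [].
# New slicing skips only the pages covered by the matched prefix:
remaining = hash_value[prefix_len // page_size :]
assert remaining == ["h2", "h3"]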