sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
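For reference, picking up these fixes is an ordinary post-release upgrade with a pinned version:

pip install --upgrade "sglang==0.4.10.post2"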
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -29,6 +29,7 @@ import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
+from enum import Enum
 from http import HTTPStatus
 from typing import (
     Any,
@@ -70,7 +71,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -116,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -202,13 +203,29 @@ class TokenizerManager:
 
         if self.model_config.is_multimodal:
            import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
         transport_mode = _determine_tensor_transport_mode(self.server_args)
 
         # We want to parallelize the image pre-processing so we create an executor for it
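The new block retries processor construction when a model repo ships only a fast image processor. Below is a minimal standalone sketch of the same fallback pattern using Hugging Face's AutoProcessor directly; the helper name load_processor_with_fallback is illustrative, not part of sglang:

from transformers import AutoProcessor

def load_processor_with_fallback(path: str, use_fast: bool):
    """Try the requested processor; retry with use_fast=True when no slow version exists."""
    try:
        return AutoProcessor.from_pretrained(path, use_fast=use_fast)
    except ValueError as e:
        # Some processors only ship a fast implementation; transformers raises
        # ValueError("... does not have a slow version ...") in that case.
        if "does not have a slow version" in str(e):
            return AutoProcessor.from_pretrained(path, use_fast=True)
        raise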
@@ -225,10 +242,10 @@ class TokenizerManager:
|
|
225
242
|
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
226
243
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
227
244
|
else:
|
228
|
-
self.mm_processor = None
|
245
|
+
self.mm_processor = self.processor = None
|
229
246
|
|
230
247
|
if server_args.skip_tokenizer_init:
|
231
|
-
self.tokenizer =
|
248
|
+
self.tokenizer = None
|
232
249
|
else:
|
233
250
|
self.tokenizer = get_tokenizer(
|
234
251
|
server_args.tokenizer_path,
|
@@ -255,6 +272,7 @@ class TokenizerManager:
         self.health_check_failed = False
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
+        self.server_status = ServerStatus.Starting
 
         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -1069,38 +1087,56 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> LoadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-        async with self.lora_update_lock:
-            # Generate new uniquely identifiable LoRARef object.
-            new_adapter = LoRARef(
-                lora_name=obj.lora_name,
-                lora_path=obj.lora_path,
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
             )
 
-            # Trigger the actual loading operation at the backend processes.
-            obj.lora_id = new_adapter.lora_id
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )
 
-            # Register the LoRA adapter only after loading is successful.
-            if result.success:
-                await self.lora_registry.register(new_adapter)
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                )
 
-            return result
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )
 
     async def unload_lora_adapter(
         self,
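With this change, validation failures (LoRA disabled, dp_size != 1, or a full registry) surface as a failed LoadLoRAAdapterReqOutput instead of an unhandled exception. A condensed sketch of the new guard, with a hypothetical LoadResult standing in for the real request/output types:

from dataclasses import dataclass
from typing import Optional

@dataclass
class LoadResult:  # stand-in for LoadLoRAAdapterReqOutput
    success: bool
    error_message: Optional[str] = None

def try_load(name: str, num_registered: int, max_loaded: Optional[int]) -> LoadResult:
    try:
        # Mirrors the new max_loaded_loras guard taken inside the update lock.
        if max_loaded is not None and num_registered >= max_loaded:
            raise ValueError(
                f"Cannot load LoRA adapter {name}: "
                f"maximum number of loaded LoRA adapters is {max_loaded}."
            )
        return LoadResult(success=True)
    except ValueError as e:
        # The caller now gets a structured failure instead of a server error.
        return LoadResult(success=False, error_message=str(e))

print(try_load("my-adapter", num_registered=8, max_loaded=8).error_message)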
@@ -1108,37 +1144,41 @@ class TokenizerManager:
         _: Optional[fastapi.Request] = None,
     ) -> UnloadLoRAAdapterReqOutput:
         self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )
 
-        assert (
-            obj.lora_name is not None
-        ), "lora_name must be provided to unload LoRA adapter"
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )
 
-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start unload Lora adapter. Lora name=%s",
-            obj.lora_name,
-        )
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )
 
-        async with self.lora_update_lock:
-            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-            # from being started.
-            lora_id = await self.lora_registry.unregister(obj.lora_name)
-            obj.lora_id = lora_id
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id
 
-            # Initiate the actual unloading operation at the backend processes only after all
-            # ongoing requests using this LoRA adapter are finished.
-            await self.lora_registry.wait_for_unload(lora_id)
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]
 
-            return result
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
 
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1767,6 +1807,8 @@ class TokenizerManager:
         asyncio.create_task(asyncio.to_thread(background_task))
 
     def _handle_abort_req(self, recv_obj):
+        if is_health_check_generate_req(recv_obj):
+            return
         state = self.rid_to_state[recv_obj.rid]
         state.finished = True
         if recv_obj.finished_reason:
@@ -1901,6 +1943,16 @@ class TokenizerManager:
         return scores
 
 
+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
     is_cross_node = server_args.dist_init_addr
 
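ServerStatus gives the HTTP layer a single place to decide whether a health probe should pass; only Up counts as healthy, so Starting, UnHealthy, and Crashed all fail the check. A small self-contained sketch of how a handler might consume it (the health_code helper is illustrative, not sglang API):

from enum import Enum

class ServerStatus(Enum):  # as added in this release
    Up = "Up"
    Starting = "Starting"
    UnHealthy = "UnHealthy"
    Crashed = "Crashed"

    def is_healthy(self) -> bool:
        return self == ServerStatus.Up

def health_code(status: ServerStatus) -> int:
    # A /health endpoint can map the enum straight to an HTTP status.
    return 200 if status.is_healthy() else 503

assert health_code(ServerStatus.Starting) == 503
assert health_code(ServerStatus.Up) == 200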
sglang/srt/managers/utils.py
CHANGED
@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional
 
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
 
@@ -38,3 +39,46 @@ def validate_input_length(
         return error_msg
 
     return None
+
+
+class DPBalanceMeta:
+    """
+    This class will be use in scheduler and dp controller
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # we must destructor this class manually
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
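DPBalanceMeta shares per-worker token counts between the data-parallel controller and the scheduler processes through a multiprocessing.Manager; __getstate__ drops the unpicklable manager handle so the object can cross process boundaries. A runnable sketch of the underlying mechanism (names are illustrative):

import multiprocessing as mp

def scheduler(local_tokens, mutex, rank: int) -> None:
    # Each scheduler process updates its own slot under the shared lock,
    # like DPBalanceMeta.shared_state.local_tokens.
    with mutex:
        local_tokens[rank] += 128

if __name__ == "__main__":
    manager = mp.Manager()
    local_tokens = manager.list([0, 0])  # one counter per DP worker
    mutex = manager.Lock()
    procs = [mp.Process(target=scheduler, args=(local_tokens, mutex, r)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(list(local_tokens))  # [128, 128]: both workers' updates are visible
    manager.shutdown()  # DPBalanceMeta.destructor() must do this explicitly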
sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
ADDED
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+from torch.utils.cpp_extension import load
+
+_abs_path = os.path.dirname(os.path.abspath(__file__))
+radix_tree_cpp = load(
+    name="radix_tree_cpp",
+    sources=[
+        f"{_abs_path}/tree_v2_binding.cpp",
+        f"{_abs_path}/tree_v2_debug.cpp",
+        f"{_abs_path}/tree_v2.cpp",
+    ],
+    extra_cflags=["-O3", "-std=c++20"],
+)
+
+if TYPE_CHECKING:
+
+    class TreeNodeCpp:
+        """
+        A placeholder for the TreeNode class. Cannot be constructed elsewhere.
+        """
+
+    class IOHandle:
+        """
+        A placeholder for the IOHandle class. Cannot be constructed elsewhere.
+        """
+
+    class RadixTreeCpp:
+        def __init__(
+            self,
+            disabled: bool,
+            host_size: Optional[int],
+            page_size: int,
+            write_through_threshold: int,
+        ):
+            """
+            Initializes the RadixTreeCpp instance.
+            Args:
+                disabled (bool): If True, the radix tree is disabled.
+                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
+                page_size (int): Size of the page for the radix tree.
+                write_through_threshold (int): Threshold for writing through from GPU to CPU.
+            """
+            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
+                disabled, host_size, page_size, write_through_threshold
+            )
+
+        def match_prefix(
+            self, prefix: List[int]
+        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+            """
+            Matches a prefix in the radix tree.
+            Args:
+                prefix (List[int]): The prefix to match.
+            Returns:
+                Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
+                    0. A list of indices that is matched by the prefix on the GPU.
+                    1. Sum length of the indices matched on the CPU.
+                    2. The last node of the prefix matched on the GPU.
+                    3. The last node of the prefix matched on the CPU.
+            """
+            return self.tree.match_prefix(prefix)
+
+        def evict(self, num_tokens: int) -> List[torch.Tensor]:
+            """
+            Evicts a number of tokens from the radix tree.
+            Args:
+                num_tokens (int): The number of tokens to evict.
+            Returns:
+                List[torch.Tensor]: A list of indices that were evicted.
+            """
+            return self.tree.evict(num_tokens)
+
+        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
+            """
+            Locks or unlocks a reference to a tree node.
+            After locking, the node will not be evicted from the radix tree.
+            Args:
+                handle (TreeNodeCpp): The tree node to lock or unlock.
+                lock (bool): If True, locks the node; if False, unlocks it.
+            """
+            return self.tree.lock_ref(handle, lock)
+
+        def writing_through(
+            self, key: List[int], indices: torch.Tensor
+        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+            """
+            Inserts a key-value pair into the radix tree and perform write-through check.
+            Args:
+                key (List[int]): The key to insert.
+                indices (torch.Tensor): The value associated with the key.
+            Returns:
+                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+                    0. A list of (IOHandle, device indices, host indices) tuples.
+                       These IOhandles require write-through to the CPU in python side.
+                    1. The number of indices that are matched on device.
+            """
+            return self.tree.writing_through(key, indices)
+
+        def loading_onboard(
+            self,
+            host_node: TreeNodeCpp,
+            new_device_indices: torch.Tensor,
+        ) -> Tuple[IOHandle, List[torch.Tensor]]:
+            """
+            Updates the device indices of tree nodes within a range on the tree.
+            Args:
+                host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
+                new_device_indices (torch.Tensor): The new device indices to set.
+                    The length of this tensor must be exactly host indices length.
+            Returns:
+                Tuple[IOHandle, List[torch.Tensor]]:
+                    0. An IOHandle that requires loading to the CPU in python side.
+                    1. A list of host indices corresponding to the new device indices.
+            """
+            return self.tree.loading_onboard(host_node, new_device_indices)
+
+        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the write-through process for a tree node.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the write-through; if False, just indicates failure.
+            """
+            return self.tree.commit_writing_through(handle, success)
+
+        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the load onboard process for tree nodes within a range on the tree.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the load-onboard; if False, just indicates failure.
+            """
+            return self.tree.commit_loading_onboard(handle, success)
+
+        def evictable_size(self) -> int:
+            """
+            Returns the size of the evictable part of the radix tree.
+            This is the size of the part that can be evicted from the GPU (ref_count = 0).
+            Returns:
+                int: The size of the evictable part.
+            """
+            return self.tree.evictable_size()
+
+        def protected_size(self) -> int:
+            """
+            Returns the size of the protected part of the radix tree.
+            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
+            Returns:
+                int: The size of the protected part.
+            """
+            return self.tree.protected_size()
+
+        def total_size(self) -> int:
+            """
+            Returns the total size of the radix tree (including CPU nodes).
+            Returns:
+                int: The total size of the radix tree.
+            """
+            return self.tree.total_size()
+
+        def reset(self) -> None:
+            """
+            Resets the radix tree, clearing all nodes and indices.
+            """
+            return self.tree.reset()
+
+        def debug_print(self) -> None:
+            """
+            Prints the internal state of the radix tree for debugging purposes.
+            """
+            return self.tree.debug_print()
+
+else:
+    # Real implementation of the classes for runtime
+    RadixTreeCpp = radix_tree_cpp.RadixTree
+    TreeNodeCpp = object
+    IOHandle = object
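The new wrapper JIT-compiles the C++ tree at import time, and its host-offload API is a two-phase commit: writing_through returns IOHandles whose actual data movement happens on the Python side, and each handle is then confirmed with commit_writing_through. A hedged sketch of the expected call pattern (copy_device_to_host is a hypothetical IO routine, and constructing the tree assumes the bundled C++ sources compile):

import torch

def write_back(tree, key: list, device_indices: torch.Tensor, copy_device_to_host) -> int:
    # Phase 1: insert the key and learn which nodes crossed the
    # write-through threshold and therefore need a device-to-host copy.
    pending, num_device_hits = tree.writing_through(key, device_indices)
    for handle, dev_idx, host_idx in pending:
        try:
            copy_device_to_host(dev_idx, host_idx)      # actual IO happens in Python
            tree.commit_writing_through(handle, True)   # phase 2: confirm success
        except Exception:
            tree.commit_writing_through(handle, False)  # mark failed; node stays GPU-only
    return num_device_hits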
sglang/srt/mem_cache/hicache_storage.py
CHANGED
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """
 
-    # todo,
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool
 
     @abstractmethod
@@ -117,26 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location: Optional[Any] = None,
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # todo: fixing the target_location logic to enable in-place loading
-            loaded_tensor = torch.load(tensor_path)
-            if isinstance(loaded_tensor, torch.Tensor):
-                return loaded_tensor
-            else:
-                logger.error(f"Loaded data for key {key} is not a tensor.")
-                return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
+            logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None
 
     def batch_get(
         self,
         keys: List[str],
-        target_locations: Optional[Any] = None,
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
     ) -> List[torch.Tensor | None]:
         return [
@@ -159,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.save(value, tensor_path)
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")
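HiCacheFile now writes raw tensor bytes with numpy().tofile() and reads them back into the caller-provided buffer, rather than round-tripping through torch.save/torch.load and the pickler. A runnable sketch of the same byte-level round-trip; it uses copy_ for clarity where the diff applies set_() on the untyped storage to avoid even that final copy:

import torch

value = torch.arange(16, dtype=torch.float16).reshape(4, 4)

# Write path, as in HiCacheFile.set: dump the raw bytes of the tensor.
value.contiguous().view(dtype=torch.uint8).numpy().tofile("/tmp/hicache_demo.bin")

# Read path, as in HiCacheFile.get: rebuild the tensor from bytes into a
# preallocated target buffer so callers keep their existing tensor object.
target = torch.empty(4, 4, dtype=torch.float16)
with open("/tmp/hicache_demo.bin", "rb") as f:
    loaded = torch.frombuffer(bytearray(f.read()), dtype=target.dtype).reshape(target.shape)
target.copy_(loaded)

assert torch.equal(target, value)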
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
     ):
+
+        if hicache_io_backend == "direct":
+            if hicache_mem_layout == "page_first":
+                hicache_mem_layout = "layer_first"
+                logger.warning(
+                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
+                )
+
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -436,7 +453,7 @@ class HiRadixCache(RadixCache):
             last_host_node,
             fetched_token_ids,
             written_indices,
-            hash_value[:min_completed_tokens],
+            hash_value[: min_completed_tokens // self.page_size],
         )
         if len(written_indices):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
@@ -529,7 +546,7 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(node.key, key)
             key = key[prefix_len:]
             host_value = host_value[prefix_len:]
-            hash_value = hash_value[prefix_len:]
+            hash_value = hash_value[prefix_len // self.page_size :]
            matched_length += prefix_len
 
             if prefix_len < len(node.key):
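Both hiradix_cache.py slicing fixes stem from the same invariant: hash_value stores one hash per page, while min_completed_tokens and prefix_len count tokens, so token counts must be converted to page counts before indexing. In miniature:

page_size = 64
hash_value = ["h0", "h1", "h2", "h3", "h4"]  # one hash per 64-token page
min_completed_tokens = 192                   # tokens confirmed written to the host

# Pre-fix: slicing by token count keeps every page hash (192 > len(hash_value)).
assert hash_value[:min_completed_tokens] == hash_value

# Post-fix: 192 tokens // 64 tokens per page keeps the first 3 page hashes.
assert hash_value[: min_completed_tokens // page_size] == ["h0", "h1", "h2"]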