sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- sglang/bench_serving.py +72 -10
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +6 -16
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +582 -125
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +79 -6
- sglang/srt/layers/quantization/__init__.py +137 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +44 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -127
- sglang/srt/managers/scheduler.py +29 -23
- sglang/srt/managers/tokenizer_manager.py +1 -2
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +16 -13
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +64 -59
- sglang/srt/model_loader/loader.py +19 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +568 -0
- sglang/srt/models/deepseek_janus_pro.py +12 -17
- sglang/srt/models/deepseek_v2.py +339 -123
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +20 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +106 -93
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +120 -25
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +94 -25
- sglang/srt/utils.py +137 -51
- sglang/test/runners.py +27 -2
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +14 -27
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -379,7 +379,7 @@ class Scheduler(
         # Init profiler
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
-        self.
+        self.profiler_activities: Optional[List[str]] = None
         self.profiler_target_forward_ct: Optional[int] = None

         # Init metrics stats
@@ -1110,7 +1110,7 @@ class Scheduler(
         )
         if memory_leak:
             msg = (
-                "
+                "token_to_kv_pool_allocator memory leak detected! "
                 f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
                 f"{self.token_to_kv_pool_allocator.available_size()=}\n"
                 f"{self.tree_cache.evictable_size()=}\n"
@@ -1121,7 +1121,7 @@ class Scheduler(

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             msg = (
-                "
+                "req_to_token_pool memory leak detected!"
                 f"available_size={len(self.req_to_token_pool.free_slots)}, "
                 f"total_size={self.req_to_token_pool.size}\n"
             )
@@ -1186,7 +1186,7 @@ class Scheduler(
         ret = None

         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
             ret, _ = self.prepare_dp_attn_batch(ret)

         return ret
@@ -1282,7 +1282,7 @@ class Scheduler(
         ]

         if self.enable_hierarchical_cache:
-            self.tree_cache.
+            self.tree_cache.ready_to_load_cache()

         if adder.new_chunked_req is not None:
             assert self.chunked_req is None
@@ -1703,18 +1703,12 @@ class Scheduler(
     def save_remote_model(self, params):
         url = params["url"]

-
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker

         worker.model_runner.save_remote_model(url)

     def save_sharded_model(self, params):
-
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker

         worker.model_runner.save_sharded_model(
             path=params["path"],
@@ -1813,7 +1807,11 @@ class Scheduler(
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
-                recv_req.output_dir,
+                recv_req.output_dir,
+                recv_req.num_steps,
+                recv_req.activities,
+                recv_req.with_stack,
+                recv_req.record_shapes,
             )
         else:
             return self.stop_profile()
@@ -1823,8 +1821,10 @@ class Scheduler(
         output_dir: Optional[str],
         num_steps: Optional[int],
         activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
     ) -> None:
-        if self.
+        if self.profiler_activities:
             return ProfileReqOutput(
                 success=False,
                 message="Profiling is already in progress. Call /stop_profile first.",
@@ -1836,7 +1836,7 @@ class Scheduler(
             activities = ["CPU", "GPU"]

         self.torch_profiler_output_dir = output_dir
-        self.
+        self.profiler_activities = activities
         logger.info(
             "Profiling starts. Traces will be saved to: %s",
             self.torch_profiler_output_dir,
@@ -1853,13 +1853,17 @@ class Scheduler(
         if torchprof_activities:
             self.torch_profiler = torch.profiler.profile(
                 activities=torchprof_activities,
-                with_stack=True,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
             )
             self.torch_profiler.start()

         if "MEM" in activities:
             torch.cuda.memory._record_memory_history(max_entries=100000)

+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+
         if num_steps:
             self.profiler_target_forward_ct = self.forward_ct + num_steps
             # The caller will be notified when reaching profiler_target_forward_ct
@@ -1868,7 +1872,7 @@ class Scheduler(
         return ProfileReqOutput(success=True, message="Succeeded")

     def stop_profile(self) -> None:
-        if self.
+        if self.profiler_activities is None:
             return

         logger.info("Stop profiling...")
@@ -1881,21 +1885,24 @@ class Scheduler(
             )
         )

-        if "MEM" in self.
+        if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
-                self.
+                self.torch_profiler_output_dir,
                 str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)

+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
         logger.info(
             "Profiling done. Traces are saved to: %s",
             self.torch_profiler_output_dir,
         )
         self.torch_profiler = None
         self.torch_profiler_output_dir = None
-        self.
+        self.profiler_activities = None

         if self.profiler_target_forward_ct:
             self.send_to_tokenizer.send_pyobj(
@@ -1963,7 +1970,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
     # Generate the prefix
     if dp_rank is None:
         prefix = f" TP{tp_rank}"
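The profiler changes above forward two new request fields (with_stack, record_shapes) into torch.profiler.profile and introduce a "CUDA_PROFILER" activity that brackets execution with cudaProfilerStart/Stop. Below is a minimal, self-contained sketch of the same bracketing pattern; the helper name and step function are illustrative, not part of sglang:

import torch

def run_with_cuda_profiler(step_fn, num_steps: int):
    # cudaProfilerStart/Stop delimit the capture window for an external
    # profiler, e.g. Nsight Systems launched as `nsys profile -c cudaProfilerApi`.
    torch.cuda.cudart().cudaProfilerStart()
    try:
        for _ in range(num_steps):
            step_fn()
        torch.cuda.synchronize()  # flush outstanding kernels into the capture
    finally:
        torch.cuda.cudart().cudaProfilerStop()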
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -261,7 +261,6 @@ class TokenizerManager:
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -737,7 +736,7 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
-        ), "dp_size must be for update weights from distributed"
+        ), "dp_size must be 1 for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.
sglang/srt/managers/tp_worker.py
CHANGED
@@ -132,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

+        # A reference make this class has the same member as TpModelWorkerClient
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
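The added self-reference mirrors TpModelWorkerClient, whose .worker attribute points at the wrapped TpModelWorker; it is what lets the scheduler's save_remote_model and save_sharded_model above drop their conditional branching and always reach the model runner through tp_worker.worker. A reduced sketch of the pattern (class names abbreviated, model_runner a stand-in object):

class Worker:
    def __init__(self):
        self.model_runner = object()  # stand-in for the real ModelRunner
        # Self-reference: a bare Worker now exposes the same `.worker`
        # member as the client wrapper below.
        self.worker = self

class WorkerClient:
    def __init__(self, worker: Worker):
        self.worker = worker  # the wrapped worker

for w in (Worker(), WorkerClient(Worker())):
    _ = w.worker.model_runner  # one uniform access path, no branching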
sglang/srt/managers/utils.py
CHANGED
@@ -1,11 +1,6 @@
-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import
-
-import torch
+from typing import Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -16,7 +16,6 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
-from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match

 logger = logging.getLogger(__name__)

@@ -31,29 +30,25 @@ class HiRadixCache(RadixCache):
         page_size: int,
         hicache_ratio: float,
     ):
-        if page_size != 1:
-            raise ValueError(
-                "Page size larger than 1 is not yet supported in HiRadixCache."
-            )
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio
+                self.kv_cache, hicache_ratio, page_size
             )
         else:
-            raise ValueError(f"
+            raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
-        self.page_size = page_size

         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
+            page_size,
             load_cache_event=self.load_cache_event,
         )
@@ -65,7 +60,7 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = 1
         self.load_back_threshold = 10
         super().__init__(
-            req_to_token_pool, token_to_kv_pool_allocator,
+            req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
         )

     def reset(self):
@@ -210,9 +205,9 @@ class HiRadixCache(RadixCache):
             # only evict the host value of evicted nodes
             if not x.evicted:
                 continue
-            assert x.lock_ref == 0 and x.host_value is not None

-
+            num_evicted += self.cache_controller.evict_host(x.host_value)
+
             for k, v in x.parent.children.items():
                 if v == x:
                     break
@@ -299,18 +294,26 @@ class HiRadixCache(RadixCache):

         return last_node, prefix_indices

-    def
+    def ready_to_load_cache(self):
         self.load_cache_event.set()

     def match_prefix(self, key: List[int], include_evicted=False, **kwargs):
-
-
+        empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
+        if self.disable or len(key) == 0:
+            if include_evicted:
+                return empty_value, self.root_node, self.root_node
+            else:
+                return empty_value, self.root_node
+
+        if self.page_size != 1:
+            page_aligned_len = len(key) // self.page_size * self.page_size
+            key = key[:page_aligned_len]

         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
             value = torch.cat(value)
         else:
-            value =
+            value = empty_value

         last_node_global = last_node
         while last_node.evicted:
@@ -323,11 +326,13 @@ class HiRadixCache(RadixCache):

     def _match_prefix_helper(self, node: TreeNode, key: List):
         node.last_access_time = time.time()
+        child_key = self.get_child_key_fn(key)
         value = []
-
-
+
+        while len(key) > 0 and child_key in node.children.keys():
+            child = node.children[child_key]
             child.last_access_time = time.time()
-            prefix_len =
+            prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
                 if not new_node.evicted:
@@ -339,12 +344,16 @@ class HiRadixCache(RadixCache):
             value.append(child.value)
             node = child
             key = key[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)
+
         return value, node

     def _split_node(self, key, child: TreeNode, split_len: int):
         # child node split into new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len]: child}
+        new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -361,7 +370,7 @@ class HiRadixCache(RadixCache):
             child.host_value = child.host_value[split_len:]
         child.parent = new_node
         child.key = child.key[split_len:]
-        new_node.parent.children[key
+        new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node

     def _insert_helper(self, node: TreeNode, key: List, value):
@@ -369,52 +378,53 @@ class HiRadixCache(RadixCache):
         if len(key) == 0:
             return 0

-
-
-            prefix_len = _key_match(child.key, key)
+        child_key = self.get_child_key_fn(key)
+        total_prefix_length = 0

-
-
+        while len(key) > 0 and child_key in node.children.keys():
+            node = node.children[child_key]
+            node.last_access_time = time.time()
+            prefix_len = self.key_match_fn(node.key, key)
+
+            if prefix_len == len(node.key):
+                if node.evicted:
                     # change the reference if the node is evicted
                     # this often happens in the case of KV cache recomputation
-
-            self.token_to_kv_pool_host.update_synced(
-            self.evictable_size_ += len(value
-            return self._insert_helper(
-                child, key[prefix_len:], value[prefix_len:]
-            )
+                    node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(node.host_value)
+                    self.evictable_size_ += len(node.value)
                 else:
-            self.inc_hit_count(
-
-            child, key[prefix_len:], value[prefix_len:]
-            )
-
-            # partial match, split the node
-            new_node = self._split_node(child.key, child, prefix_len)
-            if new_node.evicted:
-                new_node.value = value[:prefix_len]
-                self.token_to_kv_pool_host.update_synced(new_node.host_value)
-                self.evictable_size_ += len(new_node.value)
-            return self._insert_helper(
-                new_node, key[prefix_len:], value[prefix_len:]
-            )
+                    self.inc_hit_count(node)
+                total_prefix_length += prefix_len
             else:
-
-
-
-
+                # partial match, split the node
+                new_node = self._split_node(node.key, node, prefix_len)
+                if new_node.evicted:
+                    new_node.value = value[:prefix_len]
+                    self.token_to_kv_pool_host.update_synced(new_node.host_value)
+                    self.evictable_size_ += len(new_node.value)
+                else:
+                    self.inc_hit_count(new_node)
+                total_prefix_length += prefix_len
+                node = new_node
+
+            key = key[prefix_len:]
+            value = value[prefix_len:]
+
+            if len(key):
+                child_key = self.get_child_key_fn(key)

         if len(key):
             new_node = TreeNode()
             new_node.parent = node
             new_node.key = key
             new_node.value = value
-            node.children[
+            node.children[child_key] = new_node
             self.evictable_size_ += len(value)

             if self.cache_controller.write_policy == "write_through":
                 self.write_backup(new_node)
-        return
+        return total_prefix_length

     def _collect_leaves_device(self):
         def is_leaf(node):
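Two ideas recur throughout the HiRadixCache changes: keys are truncated to a page boundary before matching, and children are looked up through get_child_key_fn so the tree stays page-granular. A standalone sketch of those two helpers under the page_size > 1 assumption; the function names follow the diff, but the canonical definitions live in radix_cache.py and may differ:

from typing import List, Tuple

def page_align(key: List[int], page_size: int) -> List[int]:
    # Mirrors `page_aligned_len = len(key) // self.page_size * self.page_size`:
    # a trailing partial page can never be cached, so it is dropped.
    return key[: len(key) // page_size * page_size]

def get_child_key(key: List[int], page_size: int) -> Tuple[int, ...]:
    # Children are indexed by the first page of the remaining key,
    # the paged analogue of keying on key[0] when page_size == 1.
    return tuple(key[:page_size])

tokens = list(range(1, 11))            # 10 tokens with page_size 4
aligned = page_align(tokens, 4)        # [1..8]; tokens 9 and 10 are dropped
child_key = get_child_key(aligned, 4)  # (1, 2, 3, 4)
print(aligned, child_key)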
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -185,6 +185,12 @@ class TokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_slots
+
+    def restore_state(self, free_slots):
+        self.free_slots = free_slots
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
@@ -602,8 +608,9 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-        pin_memory: bool
-        device: str
+        pin_memory: bool,
+        device: str,
+        page_size: int,
     ):
         assert (
             host_to_device_ratio >= 1
@@ -614,8 +621,11 @@ class HostKVCache(abc.ABC):
         self.host_to_device_ratio = host_to_device_ratio
         self.pin_memory = pin_memory
         self.device = device
+        self.page_size = page_size

         self.size = int(device_pool.size * host_to_device_ratio)
+        # Align the host memory pool size to the page size
+        self.size = self.size - (self.size % self.page_size)
         self.dtype = device_pool.store_dtype
         self.size_per_token = self.get_size_per_token()

@@ -769,10 +779,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
-
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.head_num = self.device_pool.head_num
@@ -805,16 +818,48 @@ class MHATokenToKVPoolHost(HostKVCache):
     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, :, indices] = flat_data

+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
+                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[0, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+            device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[1, layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
+

 class MLATokenToKVPoolHost(HostKVCache):
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
-
+        page_size: int,
+        pin_memory: bool = True,
         device: str = "cpu",
     ):
-        super().__init__(
+        super().__init__(
+            device_pool, host_to_device_ratio, pin_memory, device, page_size
+        )

     def get_size_per_token(self):
         self.kv_lora_rank = self.device_pool.kv_lora_rank
@@ -851,3 +896,24 @@ class MLATokenToKVPoolHost(HostKVCache):

     def assign_flat_data(self, indices, flat_data):
         self.kv_buffer[:, indices] = flat_data
+
+    def write_page_all_layers(self, host_indices, device_indices, device_pool):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            for j in range(self.layer_num):
+                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
+                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
+                    non_blocking=True,
+                )
+
+    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
+        device_indices_cpu = device_indices[:: self.page_size].cpu()
+        for i in range(len(device_indices_cpu)):
+            h_index = host_indices[i * self.page_size]
+            d_index = device_indices_cpu[i]
+            device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_(
+                self.kv_buffer[layer_id, h_index : h_index + self.page_size],
+                non_blocking=True,
+            )
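The new write_page_all_layers/load_page_per_layer methods move KV cache between host and device one page at a time: device_indices[::page_size] picks one representative slot per page, and each copy_(..., non_blocking=True) transfers a contiguous page. The host pool size alignment above (self.size - self.size % page_size) guarantees these page-sized copies never run past the end of the buffer. A minimal sketch of the indexing pattern on plain CPU tensors (buffer shapes simplified to a single layer):

import torch

page_size = 4
host = torch.zeros(16, 8)    # host pool: 4 pages of 4 slots, hidden size 8
device = torch.randn(16, 8)  # stand-in for one layer's device k_buffer

# Page-aligned slot indices covering two pages (slots 0-3 and 8-11).
device_indices = torch.tensor([0, 1, 2, 3, 8, 9, 10, 11])
host_indices = device_indices.clone()

# device_indices[::page_size] yields the first slot of each page.
for i, d_index in enumerate(device_indices[::page_size].tolist()):
    h_index = int(host_indices[i * page_size])
    # One contiguous page per copy; with real CUDA tensors and pinned host
    # memory, non_blocking=True lets the transfers overlap with compute.
    host[h_index : h_index + page_size].copy_(
        device[d_index : d_index + page_size], non_blocking=True
    )

assert torch.equal(host[0:4], device[0:4])
assert torch.equal(host[8:12], device[8:12])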
sglang/srt/mem_cache/paged_allocator.py
CHANGED
@@ -190,6 +190,30 @@ class PagedTokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_pages) * self.page_size

+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        # page-aligned allocation, returning contiguous indices of pages
+        if self.debug_mode:
+            assert (
+                need_size % self.page_size == 0
+            ), "The allocation size should be page-aligned"
+
+        num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            return None
+
+        out_pages = self.free_pages[:num_pages]
+        self.free_pages = self.free_pages[num_pages:]
+
+        out_indices = (
+            out_pages[:, None] * self.page_size
+            + torch.arange(self.page_size, device=self.device)
+        ).reshape(-1)
+
+        return out_indices
+
     def alloc_extend(
         self,
         prefix_lens: torch.Tensor,
@@ -218,6 +242,9 @@ class PagedTokenToKVPoolAllocator:
             next_power_of_2(extend_num_tokens),
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         merged_value = self.ret_values.item()
         num_new_pages = merged_value >> 32
         if num_new_pages > len(self.free_pages):
@@ -248,6 +275,9 @@ class PagedTokenToKVPoolAllocator:
             self.page_size,
         )

+        if self.debug_mode:
+            assert len(torch.unique(out_indices)) == len(out_indices)
+
         num_new_pages = self.ret_values.item()
         if num_new_pages > len(self.free_pages):
             return None
@@ -265,6 +295,9 @@ class PagedTokenToKVPoolAllocator:
         else:
             self.free_group.append(free_index)

+        if self.debug_mode:
+            assert len(torch.unique(self.free_pages)) == len(self.free_pages)
+
     def free_group_begin(self):
         self.is_not_in_free_group = False
         self.free_group = []
@@ -274,6 +307,12 @@ class PagedTokenToKVPoolAllocator:
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def backup_state(self):
+        return self.free_pages
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
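The new alloc above is the page-granular counterpart of the token allocator: it pops whole pages off the free list and expands each page id into per-token slot indices with a broadcasted arange, while backup_state/restore_state snapshot and swap back the free-page tensor so a caller can roll back an allocation. A minimal sketch of the index arithmetic (a toy class, not the sglang allocator, with the free list as a plain 1-D tensor):

import torch

class PageAllocatorSketch:
    def __init__(self, num_pages: int, page_size: int):
        self.page_size = page_size
        self.free_pages = torch.arange(num_pages, dtype=torch.int64)

    def alloc(self, need_size: int):
        # Page-aligned sizes only, as asserted in the diff's debug mode.
        assert need_size % self.page_size == 0, "allocation size should be page-aligned"
        num_pages = need_size // self.page_size
        if num_pages > len(self.free_pages):
            return None
        out_pages = self.free_pages[:num_pages]
        self.free_pages = self.free_pages[num_pages:]
        # Page id p covers token slots [p * page_size, (p + 1) * page_size);
        # broadcasting expands the page ids into one index per token slot.
        return (
            out_pages[:, None] * self.page_size + torch.arange(self.page_size)
        ).reshape(-1)

allocator = PageAllocatorSketch(num_pages=8, page_size=4)
print(allocator.alloc(8))    # tensor([0, 1, 2, 3, 4, 5, 6, 7]), i.e. pages 0 and 1
state = allocator.free_pages # backup_state: snapshot the free list
allocator.alloc(4)
allocator.free_pages = state # restore_state: roll the allocation back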