sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +12 -6
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +43 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +36 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +75 -68
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +2 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +26 -10
- sglang/srt/models/gpt_oss.py +0 -14
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +14 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
|
|
93
93
|
# This updates radix so others can match
|
94
94
|
self.tree_cache.cache_unfinished_req(req)
|
95
95
|
|
96
|
-
if
|
96
|
+
if batch.return_logprob:
|
97
97
|
assert extend_logprob_start_len_per_req is not None
|
98
98
|
assert extend_input_len_per_req is not None
|
99
99
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
100
100
|
extend_input_len = extend_input_len_per_req[i]
|
101
101
|
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
102
|
+
if req.return_logprob:
|
103
|
+
self.add_logprob_return_values(
|
104
|
+
i,
|
105
|
+
req,
|
106
|
+
logprob_pt,
|
107
|
+
next_token_ids,
|
108
|
+
num_input_logprobs,
|
109
|
+
logits_output,
|
110
|
+
)
|
110
111
|
logprob_pt += num_input_logprobs
|
111
112
|
|
112
113
|
if (
|
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
|
|
146
147
|
skip_stream_req = req
|
147
148
|
|
148
149
|
# Incrementally update input logprobs.
|
149
|
-
if
|
150
|
+
if batch.return_logprob:
|
150
151
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
151
152
|
extend_input_len = extend_input_len_per_req[i]
|
152
153
|
if extend_logprob_start_len < extend_input_len:
|
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
|
|
154
155
|
num_input_logprobs = (
|
155
156
|
extend_input_len - extend_logprob_start_len
|
156
157
|
)
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
158
|
+
if req.return_logprob:
|
159
|
+
self.add_input_logprob_return_values(
|
160
|
+
i,
|
161
|
+
req,
|
162
|
+
logits_output,
|
163
|
+
logprob_pt,
|
164
|
+
num_input_logprobs,
|
165
|
+
last_prefill_chunk=False,
|
166
|
+
)
|
165
167
|
logprob_pt += num_input_logprobs
|
166
168
|
|
167
169
|
self.set_next_batch_sampling_info_done(batch)
|
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
|
|
121
121
|
url = params["url"]
|
122
122
|
|
123
123
|
worker = self.tp_worker.worker
|
124
|
-
|
125
124
|
worker.model_runner.save_remote_model(url)
|
126
125
|
|
126
|
+
if self.draft_worker is not None:
|
127
|
+
draft_url = params.get("draft_url", None)
|
128
|
+
assert (
|
129
|
+
draft_url is not None
|
130
|
+
), "draft_url must be provided when draft model is enabled"
|
131
|
+
draft_worker = self.draft_worker.worker
|
132
|
+
draft_worker.model_runner.save_remote_model(draft_url)
|
133
|
+
|
127
134
|
def save_sharded_model(self, params):
|
128
135
|
worker = self.tp_worker.worker
|
129
136
|
|
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
|
|
73
73
|
BatchTokenIDOut,
|
74
74
|
BatchTokenizedEmbeddingReqInput,
|
75
75
|
BatchTokenizedGenerateReqInput,
|
76
|
+
ClearHiCacheReqInput,
|
77
|
+
ClearHiCacheReqOutput,
|
76
78
|
CloseSessionReqInput,
|
77
79
|
ConfigureLoggingReq,
|
78
80
|
EmbeddingReqInput,
|
@@ -92,6 +94,7 @@ from sglang.srt.managers.io_struct import (
|
|
92
94
|
LoadLoRAAdapterReqInput,
|
93
95
|
LoadLoRAAdapterReqOutput,
|
94
96
|
LoRAUpdateResult,
|
97
|
+
MultiTokenizerWarpper,
|
95
98
|
OpenSessionReqInput,
|
96
99
|
OpenSessionReqOutput,
|
97
100
|
ProfileReq,
|
@@ -129,6 +132,7 @@ from sglang.srt.utils import (
|
|
129
132
|
dataclass_to_string_truncated,
|
130
133
|
freeze_gc,
|
131
134
|
get_bool_env_var,
|
135
|
+
get_origin_rid,
|
132
136
|
get_zmq_socket,
|
133
137
|
kill_process_tree,
|
134
138
|
)
|
@@ -264,9 +268,15 @@ class TokenizerManager:
|
|
264
268
|
self.recv_from_detokenizer = get_zmq_socket(
|
265
269
|
context, zmq.PULL, port_args.tokenizer_ipc_name, True
|
266
270
|
)
|
267
|
-
self.
|
268
|
-
|
269
|
-
|
271
|
+
if self.server_args.tokenizer_worker_num > 1:
|
272
|
+
# Use tokenizer_worker_ipc_name in multi-tokenizer mode
|
273
|
+
self.send_to_scheduler = get_zmq_socket(
|
274
|
+
context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
|
275
|
+
)
|
276
|
+
else:
|
277
|
+
self.send_to_scheduler = get_zmq_socket(
|
278
|
+
context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
|
279
|
+
)
|
270
280
|
|
271
281
|
# Request states
|
272
282
|
self.no_create_loop = False
|
@@ -310,35 +320,7 @@ class TokenizerManager:
|
|
310
320
|
self.lora_update_lock = asyncio.Lock()
|
311
321
|
|
312
322
|
# For PD disaggregtion
|
313
|
-
self.
|
314
|
-
self.server_args.disaggregation_mode
|
315
|
-
)
|
316
|
-
self.disaggregation_transfer_backend = TransferBackend(
|
317
|
-
self.server_args.disaggregation_transfer_backend
|
318
|
-
)
|
319
|
-
# Start kv boostrap server on prefill
|
320
|
-
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
321
|
-
# only start bootstrap server on prefill tm
|
322
|
-
kv_bootstrap_server_class = get_kv_class(
|
323
|
-
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
324
|
-
)
|
325
|
-
self.bootstrap_server = kv_bootstrap_server_class(
|
326
|
-
self.server_args.disaggregation_bootstrap_port
|
327
|
-
)
|
328
|
-
is_create_store = (
|
329
|
-
self.server_args.node_rank == 0
|
330
|
-
and self.server_args.disaggregation_transfer_backend == "ascend"
|
331
|
-
)
|
332
|
-
if is_create_store:
|
333
|
-
try:
|
334
|
-
from mf_adapter import create_config_store
|
335
|
-
|
336
|
-
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
337
|
-
create_config_store(ascend_url)
|
338
|
-
except Exception as e:
|
339
|
-
error_message = f"Failed create mf store, invalid ascend_url."
|
340
|
-
error_message += f" With exception {e}"
|
341
|
-
raise error_message
|
323
|
+
self.init_disaggregation()
|
342
324
|
|
343
325
|
# For load balancing
|
344
326
|
self.current_load = 0
|
@@ -386,6 +368,9 @@ class TokenizerManager:
|
|
386
368
|
self.flush_cache_communicator = _Communicator(
|
387
369
|
self.send_to_scheduler, server_args.dp_size
|
388
370
|
)
|
371
|
+
self.clear_hicache_storage_communicator = _Communicator(
|
372
|
+
self.send_to_scheduler, server_args.dp_size
|
373
|
+
)
|
389
374
|
self.profile_communicator = _Communicator(
|
390
375
|
self.send_to_scheduler, server_args.dp_size
|
391
376
|
)
|
@@ -447,6 +432,10 @@ class TokenizerManager:
|
|
447
432
|
SlowDownReqOutput,
|
448
433
|
self.slow_down_communicator.handle_recv,
|
449
434
|
),
|
435
|
+
(
|
436
|
+
ClearHiCacheReqOutput,
|
437
|
+
self.clear_hicache_storage_communicator.handle_recv,
|
438
|
+
),
|
450
439
|
(
|
451
440
|
FlushCacheReqOutput,
|
452
441
|
self.flush_cache_communicator.handle_recv,
|
@@ -479,6 +468,37 @@ class TokenizerManager:
|
|
479
468
|
]
|
480
469
|
)
|
481
470
|
|
471
|
+
def init_disaggregation(self):
|
472
|
+
self.disaggregation_mode = DisaggregationMode(
|
473
|
+
self.server_args.disaggregation_mode
|
474
|
+
)
|
475
|
+
self.disaggregation_transfer_backend = TransferBackend(
|
476
|
+
self.server_args.disaggregation_transfer_backend
|
477
|
+
)
|
478
|
+
# Start kv boostrap server on prefill
|
479
|
+
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
480
|
+
# only start bootstrap server on prefill tm
|
481
|
+
kv_bootstrap_server_class = get_kv_class(
|
482
|
+
self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
|
483
|
+
)
|
484
|
+
self.bootstrap_server = kv_bootstrap_server_class(
|
485
|
+
self.server_args.disaggregation_bootstrap_port
|
486
|
+
)
|
487
|
+
is_create_store = (
|
488
|
+
self.server_args.node_rank == 0
|
489
|
+
and self.server_args.disaggregation_transfer_backend == "ascend"
|
490
|
+
)
|
491
|
+
if is_create_store:
|
492
|
+
try:
|
493
|
+
from mf_adapter import create_config_store
|
494
|
+
|
495
|
+
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
496
|
+
create_config_store(ascend_url)
|
497
|
+
except Exception as e:
|
498
|
+
error_message = f"Failed create mf store, invalid ascend_url."
|
499
|
+
error_message += f" With exception {e}"
|
500
|
+
raise error_message
|
501
|
+
|
482
502
|
async def generate_request(
|
483
503
|
self,
|
484
504
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
@@ -488,6 +508,15 @@ class TokenizerManager:
|
|
488
508
|
self.auto_create_handle_loop()
|
489
509
|
obj.normalize_batch_and_arguments()
|
490
510
|
|
511
|
+
if self.server_args.tokenizer_worker_num > 1:
|
512
|
+
# Modify rid, add worker_id
|
513
|
+
if isinstance(obj.rid, list):
|
514
|
+
# If it's an array, add worker_id prefix to each element
|
515
|
+
obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
|
516
|
+
else:
|
517
|
+
# If it's a single value, add worker_id prefix
|
518
|
+
obj.rid = f"{self.worker_id}_{obj.rid}"
|
519
|
+
|
491
520
|
if self.log_requests:
|
492
521
|
max_length, skip_names, _ = self.log_request_metadata
|
493
522
|
logger.info(
|
@@ -988,6 +1017,13 @@ class TokenizerManager:
|
|
988
1017
|
async def flush_cache(self) -> FlushCacheReqOutput:
|
989
1018
|
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
|
990
1019
|
|
1020
|
+
async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
|
1021
|
+
"""Clear the hierarchical cache storage."""
|
1022
|
+
# Delegate to the scheduler to handle HiCacheStorage clearing
|
1023
|
+
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
|
1024
|
+
0
|
1025
|
+
]
|
1026
|
+
|
991
1027
|
def abort_request(self, rid: str = "", abort_all: bool = False):
|
992
1028
|
if not abort_all and rid not in self.rid_to_state:
|
993
1029
|
return
|
@@ -1080,6 +1116,8 @@ class TokenizerManager:
|
|
1080
1116
|
async def _wait_for_model_update_from_disk(
|
1081
1117
|
self, obj: UpdateWeightFromDiskReqInput
|
1082
1118
|
) -> Tuple[bool, str]:
|
1119
|
+
if self.server_args.tokenizer_worker_num > 1:
|
1120
|
+
obj = MultiTokenizerWarpper(self.worker_id, obj)
|
1083
1121
|
self.send_to_scheduler.send_pyobj(obj)
|
1084
1122
|
self.model_update_result = asyncio.Future()
|
1085
1123
|
if self.server_args.dp_size == 1:
|
@@ -1299,6 +1337,8 @@ class TokenizerManager:
|
|
1299
1337
|
elif obj.session_id in self.session_futures:
|
1300
1338
|
return None
|
1301
1339
|
|
1340
|
+
if self.server_args.tokenizer_worker_num > 1:
|
1341
|
+
obj = MultiTokenizerWarpper(self.worker_id, obj)
|
1302
1342
|
self.send_to_scheduler.send_pyobj(obj)
|
1303
1343
|
|
1304
1344
|
self.session_futures[obj.session_id] = asyncio.Future()
|
@@ -1319,13 +1359,11 @@ class TokenizerManager:
|
|
1319
1359
|
# Many DP ranks
|
1320
1360
|
return [res.internal_state for res in responses]
|
1321
1361
|
|
1322
|
-
async def set_internal_state(
|
1323
|
-
self, obj: SetInternalStateReq
|
1324
|
-
) -> SetInternalStateReqOutput:
|
1362
|
+
async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
|
1325
1363
|
responses: List[SetInternalStateReqOutput] = (
|
1326
1364
|
await self.set_internal_state_communicator(obj)
|
1327
1365
|
)
|
1328
|
-
return [res.
|
1366
|
+
return [res.updated for res in responses]
|
1329
1367
|
|
1330
1368
|
async def get_load(self) -> dict:
|
1331
1369
|
# TODO(lsyin): fake load report server
|
@@ -1576,7 +1614,6 @@ class TokenizerManager:
|
|
1576
1614
|
|
1577
1615
|
async def handle_loop(self):
|
1578
1616
|
"""The event loop that handles requests"""
|
1579
|
-
|
1580
1617
|
while True:
|
1581
1618
|
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
|
1582
1619
|
self._result_dispatcher(recv_obj)
|
@@ -1596,9 +1633,12 @@ class TokenizerManager:
|
|
1596
1633
|
)
|
1597
1634
|
continue
|
1598
1635
|
|
1636
|
+
origin_rid = rid
|
1637
|
+
if self.server_args.tokenizer_worker_num > 1:
|
1638
|
+
origin_rid = get_origin_rid(rid)
|
1599
1639
|
# Build meta_info and return value
|
1600
1640
|
meta_info = {
|
1601
|
-
"id":
|
1641
|
+
"id": origin_rid,
|
1602
1642
|
"finish_reason": recv_obj.finished_reasons[i],
|
1603
1643
|
"prompt_tokens": recv_obj.prompt_tokens[i],
|
1604
1644
|
"weight_version": self.server_args.weight_version,
|
@@ -1904,6 +1944,9 @@ class TokenizerManager:
|
|
1904
1944
|
if is_health_check_generate_req(recv_obj):
|
1905
1945
|
return
|
1906
1946
|
state = self.rid_to_state[recv_obj.rid]
|
1947
|
+
origin_rid = recv_obj.rid
|
1948
|
+
if self.server_args.tokenizer_worker_num > 1:
|
1949
|
+
origin_rid = get_origin_rid(origin_rid)
|
1907
1950
|
state.finished = True
|
1908
1951
|
if recv_obj.finished_reason:
|
1909
1952
|
out = {
|
@@ -1916,7 +1959,7 @@ class TokenizerManager:
|
|
1916
1959
|
out = {
|
1917
1960
|
"text": "",
|
1918
1961
|
"meta_info": {
|
1919
|
-
"id":
|
1962
|
+
"id": origin_rid,
|
1920
1963
|
"finish_reason": {
|
1921
1964
|
"type": "abort",
|
1922
1965
|
"message": "Abort before prefill",
|
@@ -2102,6 +2145,8 @@ T = TypeVar("T")
|
|
2102
2145
|
class _Communicator(Generic[T]):
|
2103
2146
|
"""Note: The communicator now only run up to 1 in-flight request at any time."""
|
2104
2147
|
|
2148
|
+
enable_multi_tokenizer = False
|
2149
|
+
|
2105
2150
|
def __init__(self, sender, fan_out: int):
|
2106
2151
|
self._sender = sender
|
2107
2152
|
self._fan_out = fan_out
|
@@ -2118,6 +2163,8 @@ class _Communicator(Generic[T]):
|
|
2118
2163
|
assert self._result_values is None
|
2119
2164
|
|
2120
2165
|
if obj:
|
2166
|
+
if _Communicator.enable_multi_tokenizer:
|
2167
|
+
obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
|
2121
2168
|
self._sender.send_pyobj(obj)
|
2122
2169
|
|
2123
2170
|
self._result_event = asyncio.Event()
|
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
|
|
47
47
|
self.req_to_token_pool.free(req.req_pool_idx)
|
48
48
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
49
49
|
|
50
|
-
def cache_unfinished_req(self, req: Req):
|
50
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
51
51
|
kv_indices = self.req_to_token_pool.req_to_token[
|
52
52
|
req.req_pool_idx, : len(req.fill_ids)
|
53
53
|
]
|
@@ -102,6 +102,20 @@ class HiCacheStorage(ABC):
|
|
102
102
|
"""
|
103
103
|
pass
|
104
104
|
|
105
|
+
@abstractmethod
|
106
|
+
def delete(self, key: str) -> bool:
|
107
|
+
"""
|
108
|
+
Delete the entry associated with the given key.
|
109
|
+
"""
|
110
|
+
pass
|
111
|
+
|
112
|
+
@abstractmethod
|
113
|
+
def clear(self) -> bool:
|
114
|
+
"""
|
115
|
+
Clear all entries in the storage.
|
116
|
+
"""
|
117
|
+
pass
|
118
|
+
|
105
119
|
def batch_exists(self, keys: List[str]) -> int:
|
106
120
|
"""
|
107
121
|
Check if the keys exist in the storage.
|
@@ -175,11 +189,12 @@ class HiCacheFile(HiCacheStorage):
|
|
175
189
|
target_location: Optional[Any] = None,
|
176
190
|
target_sizes: Optional[Any] = None,
|
177
191
|
) -> bool:
|
178
|
-
key = self._get_suffixed_key(key)
|
179
|
-
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
180
192
|
if self.exists(key):
|
181
193
|
logger.debug(f"Key {key} already exists. Skipped.")
|
182
194
|
return True
|
195
|
+
|
196
|
+
key = self._get_suffixed_key(key)
|
197
|
+
tensor_path = os.path.join(self.file_path, f"{key}.bin")
|
183
198
|
try:
|
184
199
|
value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
|
185
200
|
return True
|
@@ -213,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
|
|
213
228
|
logger.warning(f"Key {key} does not exist. Cannot delete.")
|
214
229
|
return
|
215
230
|
|
216
|
-
def clear(self) ->
|
231
|
+
def clear(self) -> bool:
|
217
232
|
try:
|
218
233
|
for filename in os.listdir(self.file_path):
|
219
234
|
file_path = os.path.join(self.file_path, filename)
|
220
235
|
if os.path.isfile(file_path):
|
221
236
|
os.remove(file_path)
|
222
237
|
logger.info("Cleared all entries in HiCacheFile storage.")
|
238
|
+
return True
|
223
239
|
except Exception as e:
|
224
240
|
logger.error(f"Failed to clear HiCacheFile storage: {e}")
|
241
|
+
return False
|
@@ -102,10 +102,7 @@ class HiRadixCache(RadixCache):
|
|
102
102
|
self.ongoing_backup = {}
|
103
103
|
# todo: dynamically adjust the threshold
|
104
104
|
self.write_through_threshold = (
|
105
|
-
1 if hicache_write_policy == "write_through" else
|
106
|
-
)
|
107
|
-
self.write_through_threshold_storage = (
|
108
|
-
1 if hicache_write_policy == "write_through" else 3
|
105
|
+
1 if hicache_write_policy == "write_through" else 2
|
109
106
|
)
|
110
107
|
self.load_back_threshold = 10
|
111
108
|
super().__init__(
|
@@ -125,6 +122,15 @@ class HiRadixCache(RadixCache):
|
|
125
122
|
height += 1
|
126
123
|
return height
|
127
124
|
|
125
|
+
def clear_storage_backend(self):
|
126
|
+
if self.enable_storage:
|
127
|
+
self.cache_controller.storage_backend.clear()
|
128
|
+
logger.info("Hierarchical cache storage backend cleared successfully!")
|
129
|
+
return True
|
130
|
+
else:
|
131
|
+
logger.warning("Hierarchical cache storage backend is not enabled.")
|
132
|
+
return False
|
133
|
+
|
128
134
|
def write_backup(self, node: TreeNode, write_back=False):
|
129
135
|
host_indices = self.cache_controller.write(
|
130
136
|
device_indices=node.value,
|
@@ -155,8 +161,9 @@ class HiRadixCache(RadixCache):
|
|
155
161
|
self.ongoing_backup[operation_id] = node
|
156
162
|
node.protect_host()
|
157
163
|
|
158
|
-
def
|
159
|
-
|
164
|
+
def _inc_hit_count(self, node: TreeNode, chunked=False):
|
165
|
+
# skip the hit count update for chunked requests
|
166
|
+
if self.cache_controller.write_policy == "write_back" or chunked:
|
160
167
|
return
|
161
168
|
node.hit_count += 1
|
162
169
|
|
@@ -164,14 +171,6 @@ class HiRadixCache(RadixCache):
|
|
164
171
|
if node.hit_count >= self.write_through_threshold:
|
165
172
|
# write to host if the node is not backuped
|
166
173
|
self.write_backup(node)
|
167
|
-
else:
|
168
|
-
if (
|
169
|
-
self.enable_storage
|
170
|
-
and (not node.backuped_storage)
|
171
|
-
and node.hit_count >= self.write_through_threshold_storage
|
172
|
-
):
|
173
|
-
# if the node is backuped on host memory but not on storage
|
174
|
-
self.write_backup_storage(node)
|
175
174
|
|
176
175
|
def writing_check(self, write_back=False):
|
177
176
|
if write_back:
|
@@ -192,8 +191,11 @@ class HiRadixCache(RadixCache):
|
|
192
191
|
)
|
193
192
|
for _ in range(queue_size.item()):
|
194
193
|
ack_id = self.cache_controller.ack_write_queue.get()
|
195
|
-
self.
|
194
|
+
backuped_node = self.ongoing_write_through[ack_id]
|
195
|
+
self.dec_lock_ref(backuped_node)
|
196
196
|
del self.ongoing_write_through[ack_id]
|
197
|
+
if self.enable_storage:
|
198
|
+
self.write_backup_storage(backuped_node)
|
197
199
|
|
198
200
|
def loading_check(self):
|
199
201
|
while not self.cache_controller.ack_load_queue.empty():
|
@@ -376,57 +378,54 @@ class HiRadixCache(RadixCache):
|
|
376
378
|
self.writing_check()
|
377
379
|
self.loading_check()
|
378
380
|
if self.enable_storage:
|
379
|
-
self.
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
381
|
+
self.drain_storage_control_queues()
|
382
|
+
|
383
|
+
def drain_storage_control_queues(self):
|
384
|
+
"""
|
385
|
+
Combine prefetch revoke, backup ack, and host mem release checks
|
386
|
+
to minimize TP synchronization and Python overhead.
|
387
|
+
"""
|
388
|
+
cc = self.cache_controller
|
389
|
+
|
390
|
+
qsizes = torch.tensor(
|
391
|
+
[
|
392
|
+
cc.prefetch_revoke_queue.qsize(),
|
393
|
+
cc.ack_backup_queue.qsize(),
|
394
|
+
cc.host_mem_release_queue.qsize(),
|
395
|
+
],
|
396
|
+
dtype=torch.int,
|
385
397
|
)
|
386
398
|
if self.tp_world_size > 1:
|
387
|
-
# synchrnoize TP workers to make the same update to hiradix cache
|
388
399
|
torch.distributed.all_reduce(
|
389
|
-
|
390
|
-
op=torch.distributed.ReduceOp.MIN,
|
391
|
-
group=self.tp_group,
|
400
|
+
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
392
401
|
)
|
393
|
-
for _ in range(queue_size.item()):
|
394
|
-
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
395
|
-
if req_id in self.ongoing_prefetch:
|
396
|
-
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
397
|
-
last_host_node.release_host()
|
398
|
-
del self.ongoing_prefetch[req_id]
|
399
|
-
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
400
|
-
else:
|
401
|
-
# the revoked operation already got terminated
|
402
|
-
pass
|
403
402
|
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
)
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
403
|
+
n_revoke, n_backup, n_release = map(int, qsizes.tolist())
|
404
|
+
|
405
|
+
# process prefetch revokes
|
406
|
+
for _ in range(n_revoke):
|
407
|
+
req_id = cc.prefetch_revoke_queue.get()
|
408
|
+
info = self.ongoing_prefetch.pop(req_id, None)
|
409
|
+
if info is not None:
|
410
|
+
last_host_node, token_ids, _, _ = info
|
411
|
+
last_host_node.release_host()
|
412
|
+
cc.prefetch_tokens_occupied -= len(token_ids)
|
413
|
+
# else: the revoked operation already got terminated, nothing to do
|
414
|
+
|
415
|
+
# process backup acks
|
416
|
+
for _ in range(n_backup):
|
417
|
+
ack_id = cc.ack_backup_queue.get()
|
418
|
+
entry = self.ongoing_backup.pop(ack_id, None)
|
419
|
+
if entry is not None:
|
420
|
+
entry.release_host()
|
421
|
+
|
422
|
+
# release host memory
|
423
|
+
host_indices_list = []
|
424
|
+
for _ in range(n_release):
|
425
|
+
host_indices_list.append(cc.host_mem_release_queue.get())
|
426
|
+
if host_indices_list:
|
427
|
+
host_indices = torch.cat(host_indices_list, dim=0)
|
428
|
+
cc.mem_pool_host.free(host_indices)
|
430
429
|
|
431
430
|
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
432
431
|
can_terminate = True
|
@@ -509,7 +508,7 @@ class HiRadixCache(RadixCache):
|
|
509
508
|
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
|
510
509
|
|
511
510
|
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
|
512
|
-
self.cache_controller.
|
511
|
+
self.cache_controller.append_host_mem_release(
|
513
512
|
host_indices[min_completed_tokens:completed_tokens]
|
514
513
|
)
|
515
514
|
last_host_node.release_host()
|
@@ -565,7 +564,11 @@ class HiRadixCache(RadixCache):
|
|
565
564
|
len(new_input_tokens) % self.page_size
|
566
565
|
)
|
567
566
|
new_input_tokens = new_input_tokens[:prefetch_length]
|
568
|
-
if
|
567
|
+
if (
|
568
|
+
not self.enable_storage
|
569
|
+
or prefetch_length < self.prefetch_threshold
|
570
|
+
or self.cache_controller.prefetch_rate_limited()
|
571
|
+
):
|
569
572
|
return
|
570
573
|
|
571
574
|
last_host_node.protect_host()
|
@@ -573,6 +576,10 @@ class HiRadixCache(RadixCache):
|
|
573
576
|
if host_indices is None:
|
574
577
|
self.evict_host(prefetch_length)
|
575
578
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
579
|
+
if host_indices is None:
|
580
|
+
last_host_node.release_host()
|
581
|
+
# no sufficient host memory for prefetch
|
582
|
+
return
|
576
583
|
operation = self.cache_controller.prefetch(
|
577
584
|
req_id, host_indices, new_input_tokens, last_hash
|
578
585
|
)
|
@@ -672,11 +679,11 @@ class HiRadixCache(RadixCache):
|
|
672
679
|
new_node.parent.children[self.get_child_key_fn(key)] = new_node
|
673
680
|
return new_node
|
674
681
|
|
675
|
-
def
|
676
|
-
node.last_access_time = time.monotonic()
|
682
|
+
def insert(self, key: List, value, chunked=False):
|
677
683
|
if len(key) == 0:
|
678
684
|
return 0
|
679
685
|
|
686
|
+
node = self.root_node
|
680
687
|
child_key = self.get_child_key_fn(key)
|
681
688
|
total_prefix_length = 0
|
682
689
|
|
@@ -693,7 +700,7 @@ class HiRadixCache(RadixCache):
|
|
693
700
|
self.token_to_kv_pool_host.update_synced(node.host_value)
|
694
701
|
self.evictable_size_ += len(node.value)
|
695
702
|
else:
|
696
|
-
self.
|
703
|
+
self._inc_hit_count(node, chunked)
|
697
704
|
total_prefix_length += prefix_len
|
698
705
|
else:
|
699
706
|
# partial match, split the node
|
@@ -703,7 +710,7 @@ class HiRadixCache(RadixCache):
|
|
703
710
|
self.token_to_kv_pool_host.update_synced(new_node.host_value)
|
704
711
|
self.evictable_size_ += len(new_node.value)
|
705
712
|
else:
|
706
|
-
self.
|
713
|
+
self._inc_hit_count(new_node, chunked)
|
707
714
|
total_prefix_length += prefix_len
|
708
715
|
node = new_node
|
709
716
|
|
@@ -737,7 +744,7 @@ class HiRadixCache(RadixCache):
|
|
737
744
|
last_hash = new_node.hash_value[-1]
|
738
745
|
|
739
746
|
if self.cache_controller.write_policy != "write_back":
|
740
|
-
self.
|
747
|
+
self._inc_hit_count(new_node, chunked)
|
741
748
|
return total_prefix_length
|
742
749
|
|
743
750
|
def _collect_leaves_device(self):
|
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
|
|
183
183
|
self.req_to_token_pool.free(req.req_pool_idx)
|
184
184
|
self.dec_lock_ref(req.last_node)
|
185
185
|
|
186
|
-
def cache_unfinished_req(self, req: Req):
|
186
|
+
def cache_unfinished_req(self, req: Req, chunked=False):
|
187
187
|
"""Cache request when it is unfinished."""
|
188
188
|
if self.disable:
|
189
189
|
return
|
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
918
918
|
layer_num,
|
919
919
|
self.size // self.page_size + 1,
|
920
920
|
self.page_size,
|
921
|
+
1,
|
921
922
|
self.kv_lora_rank,
|
922
923
|
),
|
923
924
|
dtype=self.store_dtype,
|
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
928
929
|
layer_num,
|
929
930
|
self.size // self.page_size + 1,
|
930
931
|
self.page_size,
|
932
|
+
1,
|
931
933
|
self.qk_rope_head_dim,
|
932
934
|
),
|
933
935
|
dtype=self.store_dtype,
|
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
|
1000
1002
|
layer_id = layer.layer_id
|
1001
1003
|
if cache_k.dtype != self.dtype:
|
1002
1004
|
cache_k = cache_k.to(self.dtype)
|
1005
|
+
cache_v = cache_v.to(self.dtype)
|
1003
1006
|
|
1004
1007
|
if self.store_dtype != self.dtype:
|
1005
1008
|
cache_k = cache_k.view(self.store_dtype)
|
1009
|
+
cache_v = cache_v.view(self.store_dtype)
|
1006
1010
|
|
1007
1011
|
if cache_v is None:
|
1008
1012
|
cache_k, cache_v = cache_k.split(
|