sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +12 -0
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +24 -14
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  29. sglang/srt/layers/moe/ep_moe/layer.py +12 -6
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/topk.py +35 -12
  32. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  33. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  34. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  35. sglang/srt/layers/quantization/mxfp4.py +9 -4
  36. sglang/srt/layers/quantization/utils.py +13 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  38. sglang/srt/layers/rotary_embedding.py +28 -1
  39. sglang/srt/layers/sampler.py +29 -5
  40. sglang/srt/managers/cache_controller.py +62 -96
  41. sglang/srt/managers/detokenizer_manager.py +43 -2
  42. sglang/srt/managers/io_struct.py +27 -0
  43. sglang/srt/managers/mm_utils.py +5 -1
  44. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  45. sglang/srt/managers/scheduler.py +36 -2
  46. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  47. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  48. sglang/srt/managers/tokenizer_manager.py +86 -39
  49. sglang/srt/mem_cache/chunk_cache.py +1 -1
  50. sglang/srt/mem_cache/hicache_storage.py +20 -3
  51. sglang/srt/mem_cache/hiradix_cache.py +75 -68
  52. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  53. sglang/srt/mem_cache/memory_pool.py +4 -0
  54. sglang/srt/mem_cache/memory_pool_host.py +2 -4
  55. sglang/srt/mem_cache/radix_cache.py +5 -4
  56. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  57. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
  58. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  59. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  60. sglang/srt/model_executor/model_runner.py +5 -4
  61. sglang/srt/model_loader/loader.py +15 -24
  62. sglang/srt/model_loader/utils.py +12 -0
  63. sglang/srt/models/deepseek_v2.py +26 -10
  64. sglang/srt/models/gpt_oss.py +0 -14
  65. sglang/srt/models/llama_eagle3.py +4 -0
  66. sglang/srt/models/longcat_flash.py +1015 -0
  67. sglang/srt/models/longcat_flash_nextn.py +691 -0
  68. sglang/srt/models/qwen2.py +26 -3
  69. sglang/srt/models/qwen2_5_vl.py +65 -41
  70. sglang/srt/models/qwen2_moe.py +22 -2
  71. sglang/srt/models/transformers.py +1 -1
  72. sglang/srt/multimodal/processors/base_processor.py +4 -2
  73. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  74. sglang/srt/server_args.py +112 -55
  75. sglang/srt/speculative/eagle_worker.py +28 -8
  76. sglang/srt/utils.py +14 -0
  77. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  78. sglang/version.py +1 -1
  79. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
  80. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
  81. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  82. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  83. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -93,20 +93,21 @@ class SchedulerOutputProcessorMixin:
93
93
  # This updates radix so others can match
94
94
  self.tree_cache.cache_unfinished_req(req)
95
95
 
96
- if req.return_logprob:
96
+ if batch.return_logprob:
97
97
  assert extend_logprob_start_len_per_req is not None
98
98
  assert extend_input_len_per_req is not None
99
99
  extend_logprob_start_len = extend_logprob_start_len_per_req[i]
100
100
  extend_input_len = extend_input_len_per_req[i]
101
101
  num_input_logprobs = extend_input_len - extend_logprob_start_len
102
- self.add_logprob_return_values(
103
- i,
104
- req,
105
- logprob_pt,
106
- next_token_ids,
107
- num_input_logprobs,
108
- logits_output,
109
- )
102
+ if req.return_logprob:
103
+ self.add_logprob_return_values(
104
+ i,
105
+ req,
106
+ logprob_pt,
107
+ next_token_ids,
108
+ num_input_logprobs,
109
+ logits_output,
110
+ )
110
111
  logprob_pt += num_input_logprobs
111
112
 
112
113
  if (
@@ -146,7 +147,7 @@ class SchedulerOutputProcessorMixin:
146
147
  skip_stream_req = req
147
148
 
148
149
  # Incrementally update input logprobs.
149
- if req.return_logprob:
150
+ if batch.return_logprob:
150
151
  extend_logprob_start_len = extend_logprob_start_len_per_req[i]
151
152
  extend_input_len = extend_input_len_per_req[i]
152
153
  if extend_logprob_start_len < extend_input_len:
@@ -154,14 +155,15 @@ class SchedulerOutputProcessorMixin:
154
155
  num_input_logprobs = (
155
156
  extend_input_len - extend_logprob_start_len
156
157
  )
157
- self.add_input_logprob_return_values(
158
- i,
159
- req,
160
- logits_output,
161
- logprob_pt,
162
- num_input_logprobs,
163
- last_prefill_chunk=False,
164
- )
158
+ if req.return_logprob:
159
+ self.add_input_logprob_return_values(
160
+ i,
161
+ req,
162
+ logits_output,
163
+ logprob_pt,
164
+ num_input_logprobs,
165
+ last_prefill_chunk=False,
166
+ )
165
167
  logprob_pt += num_input_logprobs
166
168
 
167
169
  self.set_next_batch_sampling_info_done(batch)
@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
121
121
  url = params["url"]
122
122
 
123
123
  worker = self.tp_worker.worker
124
-
125
124
  worker.model_runner.save_remote_model(url)
126
125
 
126
+ if self.draft_worker is not None:
127
+ draft_url = params.get("draft_url", None)
128
+ assert (
129
+ draft_url is not None
130
+ ), "draft_url must be provided when draft model is enabled"
131
+ draft_worker = self.draft_worker.worker
132
+ draft_worker.model_runner.save_remote_model(draft_url)
133
+
127
134
  def save_sharded_model(self, params):
128
135
  worker = self.tp_worker.worker
129
136
 
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
73
73
  BatchTokenIDOut,
74
74
  BatchTokenizedEmbeddingReqInput,
75
75
  BatchTokenizedGenerateReqInput,
76
+ ClearHiCacheReqInput,
77
+ ClearHiCacheReqOutput,
76
78
  CloseSessionReqInput,
77
79
  ConfigureLoggingReq,
78
80
  EmbeddingReqInput,
@@ -92,6 +94,7 @@ from sglang.srt.managers.io_struct import (
92
94
  LoadLoRAAdapterReqInput,
93
95
  LoadLoRAAdapterReqOutput,
94
96
  LoRAUpdateResult,
97
+ MultiTokenizerWarpper,
95
98
  OpenSessionReqInput,
96
99
  OpenSessionReqOutput,
97
100
  ProfileReq,
@@ -129,6 +132,7 @@ from sglang.srt.utils import (
129
132
  dataclass_to_string_truncated,
130
133
  freeze_gc,
131
134
  get_bool_env_var,
135
+ get_origin_rid,
132
136
  get_zmq_socket,
133
137
  kill_process_tree,
134
138
  )
@@ -264,9 +268,15 @@ class TokenizerManager:
264
268
  self.recv_from_detokenizer = get_zmq_socket(
265
269
  context, zmq.PULL, port_args.tokenizer_ipc_name, True
266
270
  )
267
- self.send_to_scheduler = get_zmq_socket(
268
- context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
269
- )
271
+ if self.server_args.tokenizer_worker_num > 1:
272
+ # Use tokenizer_worker_ipc_name in multi-tokenizer mode
273
+ self.send_to_scheduler = get_zmq_socket(
274
+ context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
275
+ )
276
+ else:
277
+ self.send_to_scheduler = get_zmq_socket(
278
+ context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
279
+ )
270
280
 
271
281
  # Request states
272
282
  self.no_create_loop = False
@@ -310,35 +320,7 @@ class TokenizerManager:
310
320
  self.lora_update_lock = asyncio.Lock()
311
321
 
312
322
  # For PD disaggregtion
313
- self.disaggregation_mode = DisaggregationMode(
314
- self.server_args.disaggregation_mode
315
- )
316
- self.disaggregation_transfer_backend = TransferBackend(
317
- self.server_args.disaggregation_transfer_backend
318
- )
319
- # Start kv boostrap server on prefill
320
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
321
- # only start bootstrap server on prefill tm
322
- kv_bootstrap_server_class = get_kv_class(
323
- self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
324
- )
325
- self.bootstrap_server = kv_bootstrap_server_class(
326
- self.server_args.disaggregation_bootstrap_port
327
- )
328
- is_create_store = (
329
- self.server_args.node_rank == 0
330
- and self.server_args.disaggregation_transfer_backend == "ascend"
331
- )
332
- if is_create_store:
333
- try:
334
- from mf_adapter import create_config_store
335
-
336
- ascend_url = os.getenv("ASCEND_MF_STORE_URL")
337
- create_config_store(ascend_url)
338
- except Exception as e:
339
- error_message = f"Failed create mf store, invalid ascend_url."
340
- error_message += f" With exception {e}"
341
- raise error_message
323
+ self.init_disaggregation()
342
324
 
343
325
  # For load balancing
344
326
  self.current_load = 0
@@ -386,6 +368,9 @@ class TokenizerManager:
386
368
  self.flush_cache_communicator = _Communicator(
387
369
  self.send_to_scheduler, server_args.dp_size
388
370
  )
371
+ self.clear_hicache_storage_communicator = _Communicator(
372
+ self.send_to_scheduler, server_args.dp_size
373
+ )
389
374
  self.profile_communicator = _Communicator(
390
375
  self.send_to_scheduler, server_args.dp_size
391
376
  )
@@ -447,6 +432,10 @@ class TokenizerManager:
447
432
  SlowDownReqOutput,
448
433
  self.slow_down_communicator.handle_recv,
449
434
  ),
435
+ (
436
+ ClearHiCacheReqOutput,
437
+ self.clear_hicache_storage_communicator.handle_recv,
438
+ ),
450
439
  (
451
440
  FlushCacheReqOutput,
452
441
  self.flush_cache_communicator.handle_recv,
@@ -479,6 +468,37 @@ class TokenizerManager:
479
468
  ]
480
469
  )
481
470
 
471
+ def init_disaggregation(self):
472
+ self.disaggregation_mode = DisaggregationMode(
473
+ self.server_args.disaggregation_mode
474
+ )
475
+ self.disaggregation_transfer_backend = TransferBackend(
476
+ self.server_args.disaggregation_transfer_backend
477
+ )
478
+ # Start kv boostrap server on prefill
479
+ if self.disaggregation_mode == DisaggregationMode.PREFILL:
480
+ # only start bootstrap server on prefill tm
481
+ kv_bootstrap_server_class = get_kv_class(
482
+ self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
483
+ )
484
+ self.bootstrap_server = kv_bootstrap_server_class(
485
+ self.server_args.disaggregation_bootstrap_port
486
+ )
487
+ is_create_store = (
488
+ self.server_args.node_rank == 0
489
+ and self.server_args.disaggregation_transfer_backend == "ascend"
490
+ )
491
+ if is_create_store:
492
+ try:
493
+ from mf_adapter import create_config_store
494
+
495
+ ascend_url = os.getenv("ASCEND_MF_STORE_URL")
496
+ create_config_store(ascend_url)
497
+ except Exception as e:
498
+ error_message = f"Failed create mf store, invalid ascend_url."
499
+ error_message += f" With exception {e}"
500
+ raise error_message
501
+
482
502
  async def generate_request(
483
503
  self,
484
504
  obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -488,6 +508,15 @@ class TokenizerManager:
488
508
  self.auto_create_handle_loop()
489
509
  obj.normalize_batch_and_arguments()
490
510
 
511
+ if self.server_args.tokenizer_worker_num > 1:
512
+ # Modify rid, add worker_id
513
+ if isinstance(obj.rid, list):
514
+ # If it's an array, add worker_id prefix to each element
515
+ obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
516
+ else:
517
+ # If it's a single value, add worker_id prefix
518
+ obj.rid = f"{self.worker_id}_{obj.rid}"
519
+
491
520
  if self.log_requests:
492
521
  max_length, skip_names, _ = self.log_request_metadata
493
522
  logger.info(
@@ -988,6 +1017,13 @@ class TokenizerManager:
988
1017
  async def flush_cache(self) -> FlushCacheReqOutput:
989
1018
  return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
990
1019
 
1020
+ async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
1021
+ """Clear the hierarchical cache storage."""
1022
+ # Delegate to the scheduler to handle HiCacheStorage clearing
1023
+ return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
1024
+ 0
1025
+ ]
1026
+
991
1027
  def abort_request(self, rid: str = "", abort_all: bool = False):
992
1028
  if not abort_all and rid not in self.rid_to_state:
993
1029
  return
@@ -1080,6 +1116,8 @@ class TokenizerManager:
1080
1116
  async def _wait_for_model_update_from_disk(
1081
1117
  self, obj: UpdateWeightFromDiskReqInput
1082
1118
  ) -> Tuple[bool, str]:
1119
+ if self.server_args.tokenizer_worker_num > 1:
1120
+ obj = MultiTokenizerWarpper(self.worker_id, obj)
1083
1121
  self.send_to_scheduler.send_pyobj(obj)
1084
1122
  self.model_update_result = asyncio.Future()
1085
1123
  if self.server_args.dp_size == 1:
@@ -1299,6 +1337,8 @@ class TokenizerManager:
1299
1337
  elif obj.session_id in self.session_futures:
1300
1338
  return None
1301
1339
 
1340
+ if self.server_args.tokenizer_worker_num > 1:
1341
+ obj = MultiTokenizerWarpper(self.worker_id, obj)
1302
1342
  self.send_to_scheduler.send_pyobj(obj)
1303
1343
 
1304
1344
  self.session_futures[obj.session_id] = asyncio.Future()
@@ -1319,13 +1359,11 @@ class TokenizerManager:
1319
1359
  # Many DP ranks
1320
1360
  return [res.internal_state for res in responses]
1321
1361
 
1322
- async def set_internal_state(
1323
- self, obj: SetInternalStateReq
1324
- ) -> SetInternalStateReqOutput:
1362
+ async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
1325
1363
  responses: List[SetInternalStateReqOutput] = (
1326
1364
  await self.set_internal_state_communicator(obj)
1327
1365
  )
1328
- return [res.internal_state for res in responses]
1366
+ return [res.updated for res in responses]
1329
1367
 
1330
1368
  async def get_load(self) -> dict:
1331
1369
  # TODO(lsyin): fake load report server
@@ -1576,7 +1614,6 @@ class TokenizerManager:
1576
1614
 
1577
1615
  async def handle_loop(self):
1578
1616
  """The event loop that handles requests"""
1579
-
1580
1617
  while True:
1581
1618
  recv_obj = await self.recv_from_detokenizer.recv_pyobj()
1582
1619
  self._result_dispatcher(recv_obj)
@@ -1596,9 +1633,12 @@ class TokenizerManager:
1596
1633
  )
1597
1634
  continue
1598
1635
 
1636
+ origin_rid = rid
1637
+ if self.server_args.tokenizer_worker_num > 1:
1638
+ origin_rid = get_origin_rid(rid)
1599
1639
  # Build meta_info and return value
1600
1640
  meta_info = {
1601
- "id": rid,
1641
+ "id": origin_rid,
1602
1642
  "finish_reason": recv_obj.finished_reasons[i],
1603
1643
  "prompt_tokens": recv_obj.prompt_tokens[i],
1604
1644
  "weight_version": self.server_args.weight_version,
@@ -1904,6 +1944,9 @@ class TokenizerManager:
1904
1944
  if is_health_check_generate_req(recv_obj):
1905
1945
  return
1906
1946
  state = self.rid_to_state[recv_obj.rid]
1947
+ origin_rid = recv_obj.rid
1948
+ if self.server_args.tokenizer_worker_num > 1:
1949
+ origin_rid = get_origin_rid(origin_rid)
1907
1950
  state.finished = True
1908
1951
  if recv_obj.finished_reason:
1909
1952
  out = {
@@ -1916,7 +1959,7 @@ class TokenizerManager:
1916
1959
  out = {
1917
1960
  "text": "",
1918
1961
  "meta_info": {
1919
- "id": recv_obj.rid,
1962
+ "id": origin_rid,
1920
1963
  "finish_reason": {
1921
1964
  "type": "abort",
1922
1965
  "message": "Abort before prefill",
@@ -2102,6 +2145,8 @@ T = TypeVar("T")
2102
2145
  class _Communicator(Generic[T]):
2103
2146
  """Note: The communicator now only run up to 1 in-flight request at any time."""
2104
2147
 
2148
+ enable_multi_tokenizer = False
2149
+
2105
2150
  def __init__(self, sender, fan_out: int):
2106
2151
  self._sender = sender
2107
2152
  self._fan_out = fan_out
@@ -2118,6 +2163,8 @@ class _Communicator(Generic[T]):
2118
2163
  assert self._result_values is None
2119
2164
 
2120
2165
  if obj:
2166
+ if _Communicator.enable_multi_tokenizer:
2167
+ obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
2121
2168
  self._sender.send_pyobj(obj)
2122
2169
 
2123
2170
  self._result_event = asyncio.Event()
@@ -47,7 +47,7 @@ class ChunkCache(BasePrefixCache):
47
47
  self.req_to_token_pool.free(req.req_pool_idx)
48
48
  self.token_to_kv_pool_allocator.free(kv_indices)
49
49
 
50
- def cache_unfinished_req(self, req: Req):
50
+ def cache_unfinished_req(self, req: Req, chunked=False):
51
51
  kv_indices = self.req_to_token_pool.req_to_token[
52
52
  req.req_pool_idx, : len(req.fill_ids)
53
53
  ]
@@ -102,6 +102,20 @@ class HiCacheStorage(ABC):
102
102
  """
103
103
  pass
104
104
 
105
+ @abstractmethod
106
+ def delete(self, key: str) -> bool:
107
+ """
108
+ Delete the entry associated with the given key.
109
+ """
110
+ pass
111
+
112
+ @abstractmethod
113
+ def clear(self) -> bool:
114
+ """
115
+ Clear all entries in the storage.
116
+ """
117
+ pass
118
+
105
119
  def batch_exists(self, keys: List[str]) -> int:
106
120
  """
107
121
  Check if the keys exist in the storage.
@@ -175,11 +189,12 @@ class HiCacheFile(HiCacheStorage):
175
189
  target_location: Optional[Any] = None,
176
190
  target_sizes: Optional[Any] = None,
177
191
  ) -> bool:
178
- key = self._get_suffixed_key(key)
179
- tensor_path = os.path.join(self.file_path, f"{key}.bin")
180
192
  if self.exists(key):
181
193
  logger.debug(f"Key {key} already exists. Skipped.")
182
194
  return True
195
+
196
+ key = self._get_suffixed_key(key)
197
+ tensor_path = os.path.join(self.file_path, f"{key}.bin")
183
198
  try:
184
199
  value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
185
200
  return True
@@ -213,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
213
228
  logger.warning(f"Key {key} does not exist. Cannot delete.")
214
229
  return
215
230
 
216
- def clear(self) -> None:
231
+ def clear(self) -> bool:
217
232
  try:
218
233
  for filename in os.listdir(self.file_path):
219
234
  file_path = os.path.join(self.file_path, filename)
220
235
  if os.path.isfile(file_path):
221
236
  os.remove(file_path)
222
237
  logger.info("Cleared all entries in HiCacheFile storage.")
238
+ return True
223
239
  except Exception as e:
224
240
  logger.error(f"Failed to clear HiCacheFile storage: {e}")
241
+ return False
@@ -102,10 +102,7 @@ class HiRadixCache(RadixCache):
102
102
  self.ongoing_backup = {}
103
103
  # todo: dynamically adjust the threshold
104
104
  self.write_through_threshold = (
105
- 1 if hicache_write_policy == "write_through" else 3
106
- )
107
- self.write_through_threshold_storage = (
108
- 1 if hicache_write_policy == "write_through" else 3
105
+ 1 if hicache_write_policy == "write_through" else 2
109
106
  )
110
107
  self.load_back_threshold = 10
111
108
  super().__init__(
@@ -125,6 +122,15 @@ class HiRadixCache(RadixCache):
125
122
  height += 1
126
123
  return height
127
124
 
125
+ def clear_storage_backend(self):
126
+ if self.enable_storage:
127
+ self.cache_controller.storage_backend.clear()
128
+ logger.info("Hierarchical cache storage backend cleared successfully!")
129
+ return True
130
+ else:
131
+ logger.warning("Hierarchical cache storage backend is not enabled.")
132
+ return False
133
+
128
134
  def write_backup(self, node: TreeNode, write_back=False):
129
135
  host_indices = self.cache_controller.write(
130
136
  device_indices=node.value,
@@ -155,8 +161,9 @@ class HiRadixCache(RadixCache):
155
161
  self.ongoing_backup[operation_id] = node
156
162
  node.protect_host()
157
163
 
158
- def inc_hit_count(self, node: TreeNode):
159
- if self.cache_controller.write_policy == "write_back":
164
+ def _inc_hit_count(self, node: TreeNode, chunked=False):
165
+ # skip the hit count update for chunked requests
166
+ if self.cache_controller.write_policy == "write_back" or chunked:
160
167
  return
161
168
  node.hit_count += 1
162
169
 
@@ -164,14 +171,6 @@ class HiRadixCache(RadixCache):
164
171
  if node.hit_count >= self.write_through_threshold:
165
172
  # write to host if the node is not backuped
166
173
  self.write_backup(node)
167
- else:
168
- if (
169
- self.enable_storage
170
- and (not node.backuped_storage)
171
- and node.hit_count >= self.write_through_threshold_storage
172
- ):
173
- # if the node is backuped on host memory but not on storage
174
- self.write_backup_storage(node)
175
174
 
176
175
  def writing_check(self, write_back=False):
177
176
  if write_back:
@@ -192,8 +191,11 @@ class HiRadixCache(RadixCache):
192
191
  )
193
192
  for _ in range(queue_size.item()):
194
193
  ack_id = self.cache_controller.ack_write_queue.get()
195
- self.dec_lock_ref(self.ongoing_write_through[ack_id])
194
+ backuped_node = self.ongoing_write_through[ack_id]
195
+ self.dec_lock_ref(backuped_node)
196
196
  del self.ongoing_write_through[ack_id]
197
+ if self.enable_storage:
198
+ self.write_backup_storage(backuped_node)
197
199
 
198
200
  def loading_check(self):
199
201
  while not self.cache_controller.ack_load_queue.empty():
@@ -376,57 +378,54 @@ class HiRadixCache(RadixCache):
376
378
  self.writing_check()
377
379
  self.loading_check()
378
380
  if self.enable_storage:
379
- self.check_revoked_prefetch()
380
- self.check_backup_progress()
381
-
382
- def check_revoked_prefetch(self):
383
- queue_size = torch.tensor(
384
- self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
381
+ self.drain_storage_control_queues()
382
+
383
+ def drain_storage_control_queues(self):
384
+ """
385
+ Combine prefetch revoke, backup ack, and host mem release checks
386
+ to minimize TP synchronization and Python overhead.
387
+ """
388
+ cc = self.cache_controller
389
+
390
+ qsizes = torch.tensor(
391
+ [
392
+ cc.prefetch_revoke_queue.qsize(),
393
+ cc.ack_backup_queue.qsize(),
394
+ cc.host_mem_release_queue.qsize(),
395
+ ],
396
+ dtype=torch.int,
385
397
  )
386
398
  if self.tp_world_size > 1:
387
- # synchrnoize TP workers to make the same update to hiradix cache
388
399
  torch.distributed.all_reduce(
389
- queue_size,
390
- op=torch.distributed.ReduceOp.MIN,
391
- group=self.tp_group,
400
+ qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
392
401
  )
393
- for _ in range(queue_size.item()):
394
- req_id = self.cache_controller.prefetch_revoke_queue.get()
395
- if req_id in self.ongoing_prefetch:
396
- last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
397
- last_host_node.release_host()
398
- del self.ongoing_prefetch[req_id]
399
- self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
400
- else:
401
- # the revoked operation already got terminated
402
- pass
403
402
 
404
- def check_backup_progress(self):
405
- queue_size = torch.tensor(
406
- self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
407
- )
408
- if self.tp_world_size > 1:
409
- # synchrnoize TP workers to make the same update to hiradix cache
410
- torch.distributed.all_reduce(
411
- queue_size,
412
- op=torch.distributed.ReduceOp.MIN,
413
- group=self.tp_group,
414
- )
415
- for _ in range(queue_size.item()):
416
- ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
417
- host_node = self.ongoing_backup[ack_id]
418
-
419
- if completed_tokens > 0:
420
- if completed_tokens < len(host_node.key):
421
- # backup is only partially successful, split the node
422
- new_node = self._split_node(
423
- host_node.key, host_node, completed_tokens
424
- )
425
- new_node.backuped_storage = True
426
- else:
427
- host_node.backuped_storage = True
428
- host_node.release_host()
429
- del self.ongoing_backup[ack_id]
403
+ n_revoke, n_backup, n_release = map(int, qsizes.tolist())
404
+
405
+ # process prefetch revokes
406
+ for _ in range(n_revoke):
407
+ req_id = cc.prefetch_revoke_queue.get()
408
+ info = self.ongoing_prefetch.pop(req_id, None)
409
+ if info is not None:
410
+ last_host_node, token_ids, _, _ = info
411
+ last_host_node.release_host()
412
+ cc.prefetch_tokens_occupied -= len(token_ids)
413
+ # else: the revoked operation already got terminated, nothing to do
414
+
415
+ # process backup acks
416
+ for _ in range(n_backup):
417
+ ack_id = cc.ack_backup_queue.get()
418
+ entry = self.ongoing_backup.pop(ack_id, None)
419
+ if entry is not None:
420
+ entry.release_host()
421
+
422
+ # release host memory
423
+ host_indices_list = []
424
+ for _ in range(n_release):
425
+ host_indices_list.append(cc.host_mem_release_queue.get())
426
+ if host_indices_list:
427
+ host_indices = torch.cat(host_indices_list, dim=0)
428
+ cc.mem_pool_host.free(host_indices)
430
429
 
431
430
  def can_terminate_prefetch(self, operation: PrefetchOperation):
432
431
  can_terminate = True
@@ -509,7 +508,7 @@ class HiRadixCache(RadixCache):
509
508
  self.cache_controller.mem_pool_host.update_prefetch(written_indices)
510
509
 
511
510
  self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
512
- self.cache_controller.mem_pool_host.free(
511
+ self.cache_controller.append_host_mem_release(
513
512
  host_indices[min_completed_tokens:completed_tokens]
514
513
  )
515
514
  last_host_node.release_host()
@@ -565,7 +564,11 @@ class HiRadixCache(RadixCache):
565
564
  len(new_input_tokens) % self.page_size
566
565
  )
567
566
  new_input_tokens = new_input_tokens[:prefetch_length]
568
- if not self.enable_storage or prefetch_length < self.prefetch_threshold:
567
+ if (
568
+ not self.enable_storage
569
+ or prefetch_length < self.prefetch_threshold
570
+ or self.cache_controller.prefetch_rate_limited()
571
+ ):
569
572
  return
570
573
 
571
574
  last_host_node.protect_host()
@@ -573,6 +576,10 @@ class HiRadixCache(RadixCache):
573
576
  if host_indices is None:
574
577
  self.evict_host(prefetch_length)
575
578
  host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
579
+ if host_indices is None:
580
+ last_host_node.release_host()
581
+ # no sufficient host memory for prefetch
582
+ return
576
583
  operation = self.cache_controller.prefetch(
577
584
  req_id, host_indices, new_input_tokens, last_hash
578
585
  )
@@ -672,11 +679,11 @@ class HiRadixCache(RadixCache):
672
679
  new_node.parent.children[self.get_child_key_fn(key)] = new_node
673
680
  return new_node
674
681
 
675
- def _insert_helper(self, node: TreeNode, key: List, value):
676
- node.last_access_time = time.monotonic()
682
+ def insert(self, key: List, value, chunked=False):
677
683
  if len(key) == 0:
678
684
  return 0
679
685
 
686
+ node = self.root_node
680
687
  child_key = self.get_child_key_fn(key)
681
688
  total_prefix_length = 0
682
689
 
@@ -693,7 +700,7 @@ class HiRadixCache(RadixCache):
693
700
  self.token_to_kv_pool_host.update_synced(node.host_value)
694
701
  self.evictable_size_ += len(node.value)
695
702
  else:
696
- self.inc_hit_count(node)
703
+ self._inc_hit_count(node, chunked)
697
704
  total_prefix_length += prefix_len
698
705
  else:
699
706
  # partial match, split the node
@@ -703,7 +710,7 @@ class HiRadixCache(RadixCache):
703
710
  self.token_to_kv_pool_host.update_synced(new_node.host_value)
704
711
  self.evictable_size_ += len(new_node.value)
705
712
  else:
706
- self.inc_hit_count(new_node)
713
+ self._inc_hit_count(new_node, chunked)
707
714
  total_prefix_length += prefix_len
708
715
  node = new_node
709
716
 
@@ -737,7 +744,7 @@ class HiRadixCache(RadixCache):
737
744
  last_hash = new_node.hash_value[-1]
738
745
 
739
746
  if self.cache_controller.write_policy != "write_back":
740
- self.inc_hit_count(new_node)
747
+ self._inc_hit_count(new_node, chunked)
741
748
  return total_prefix_length
742
749
 
743
750
  def _collect_leaves_device(self):
@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
183
183
  self.req_to_token_pool.free(req.req_pool_idx)
184
184
  self.dec_lock_ref(req.last_node)
185
185
 
186
- def cache_unfinished_req(self, req: Req):
186
+ def cache_unfinished_req(self, req: Req, chunked=False):
187
187
  """Cache request when it is unfinished."""
188
188
  if self.disable:
189
189
  return
@@ -918,6 +918,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
918
918
  layer_num,
919
919
  self.size // self.page_size + 1,
920
920
  self.page_size,
921
+ 1,
921
922
  self.kv_lora_rank,
922
923
  ),
923
924
  dtype=self.store_dtype,
@@ -928,6 +929,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
928
929
  layer_num,
929
930
  self.size // self.page_size + 1,
930
931
  self.page_size,
932
+ 1,
931
933
  self.qk_rope_head_dim,
932
934
  ),
933
935
  dtype=self.store_dtype,
@@ -1000,9 +1002,11 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
1000
1002
  layer_id = layer.layer_id
1001
1003
  if cache_k.dtype != self.dtype:
1002
1004
  cache_k = cache_k.to(self.dtype)
1005
+ cache_v = cache_v.to(self.dtype)
1003
1006
 
1004
1007
  if self.store_dtype != self.dtype:
1005
1008
  cache_k = cache_k.view(self.store_dtype)
1009
+ cache_v = cache_v.view(self.store_dtype)
1006
1010
 
1007
1011
  if cache_v is None:
1008
1012
  cache_k, cache_v = cache_k.split(