sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/hicache_storage.py

@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional
 
 import torch
@@ -9,17 +10,6 @@ import torch
 logger = logging.getLogger(__name__)
 
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)
-
-
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()
 
@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()
 
 
+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
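The new HiCacheStorageConfig bundles the parallelism and model facts a storage backend needs, so backends no longer query distributed state themselves. A minimal construction sketch (field values are illustrative, not taken from this diff):

    from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig

    cfg = HiCacheStorageConfig(
        tp_rank=0,              # this worker's tensor-parallel rank
        tp_size=2,              # tensor-parallel world size
        is_mla_model=False,     # MLA pools skip the per-rank key suffix
        model_name="my-model",  # illustrative value
        extra_config=None,      # backend-specific options
    )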
@@ -60,7 +59,7 @@ class HiCacheStorage(ABC):
         keys: List[str],
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+    ) -> List[torch.Tensor | None] | int:
         """
         Retrieve values for multiple keys.
         Returns a list of tensors or None for each key.
@@ -96,25 +95,51 @@ class HiCacheStorage(ABC):
         pass
 
     @abstractmethod
-    def exists(self, key: str) -> bool | dict:
+    def exists(self, key: str) -> bool:
         """
         Check if the key exists in the storage.
         Returns True if the key exists, False otherwise.
         """
         pass
 
+    @abstractmethod
+    def delete(self, key: str) -> bool:
+        """
+        Delete the entry associated with the given key.
+        """
+        pass
+
+    @abstractmethod
+    def clear(self) -> bool:
+        """
+        Clear all entries in the storage.
+        """
+        pass
+
+    def batch_exists(self, keys: List[str]) -> int:
+        """
+        Check if the keys exist in the storage.
+        return the number of consecutive existing keys from the start.
+        Can be overridden by subclasses for more efficient implementation.
+        """
+        for i in range(len(keys)):
+            if not self.exists(keys[i]):
+                return i
+        return len(keys)
+
 
 class HiCacheFile(HiCacheStorage):
 
-    def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        if is_dp_attention_enabled():
-            tp_rank = get_attention_tp_rank()
-            tp_size = get_attention_tp_size()
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()
 
+        tp_rank, tp_size, is_mla = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.is_mla_model,
+        )
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
        if not os.path.exists(self.file_path) and tp_rank == 0:
            os.makedirs(self.file_path)
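Note that the default batch_exists returns the length of the leading run of hits, not the total hit count, which matches how prefix caching consumes keys in order. A standalone mirror of that loop (not sglang code) makes the contract concrete:

    def batch_exists_prefix(exists, keys):
        # Count consecutive hits from the start; stop at the first miss.
        for i, key in enumerate(keys):
            if not exists(key):
                return i
        return len(keys)

    present = {"h0", "h1", "h3"}
    print(batch_exists_prefix(present.__contains__, ["h0", "h1", "h2", "h3"]))  # 2

"h3" exists but is not counted, because "h2" breaks the consecutive prefix.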
@@ -164,11 +189,12 @@ class HiCacheFile(HiCacheStorage):
         target_location: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
     ) -> bool:
-        key = self._get_suffixed_key(key)
-        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
+
+        key = self._get_suffixed_key(key)
+        tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
@@ -202,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
             logger.warning(f"Key {key} does not exist. Cannot delete.")
             return
 
-    def clear(self) -> None:
+    def clear(self) -> bool:
         try:
             for filename in os.listdir(self.file_path):
                 file_path = os.path.join(self.file_path, filename)
                 if os.path.isfile(file_path):
                     os.remove(file_path)
             logger.info("Cleared all entries in HiCacheFile storage.")
+            return True
         except Exception as e:
             logger.error(f"Failed to clear HiCacheFile storage: {e}")
+            return False
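A usage sketch under the new constructor (assumes the wheel is installed and a single-rank, non-MLA setup):

    from sglang.srt.mem_cache.hicache_storage import HiCacheFile, HiCacheStorageConfig

    cfg = HiCacheStorageConfig(tp_rank=0, tp_size=1, is_mla_model=False, model_name=None)
    storage = HiCacheFile(cfg)  # SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR overrides file_path
    print(storage.batch_exists(["k0", "k1"]))  # 0 on an empty directory
    print(storage.clear())                     # True; failures now return False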
sglang/srt/mem_cache/hiradix_cache.py

@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
 
         if hicache_io_backend == "direct":
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
             io_backend=hicache_io_backend,
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
+            model_name=model_name,
+            storage_backend_extra_config=storage_backend_extra_config,
         )
 
         # record the nodes with ongoing write through
@@ -98,10 +102,7 @@ class HiRadixCache(RadixCache):
         self.ongoing_backup = {}
         # todo: dynamically adjust the threshold
         self.write_through_threshold = (
-            1 if hicache_write_policy == "write_through" else 3
-        )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
+            1 if hicache_write_policy == "write_through" else 2
         )
         self.load_back_threshold = 10
         super().__init__(
@@ -121,6 +122,15 @@ class HiRadixCache(RadixCache):
             height += 1
         return height
 
+    def clear_storage_backend(self):
+        if self.enable_storage:
+            self.cache_controller.storage_backend.clear()
+            logger.info("Hierarchical cache storage backend cleared successfully!")
+            return True
+        else:
+            logger.warning("Hierarchical cache storage backend is not enabled.")
+            return False
+
     def write_backup(self, node: TreeNode, write_back=False):
         host_indices = self.cache_controller.write(
             device_indices=node.value,
@@ -151,8 +161,9 @@ class HiRadixCache(RadixCache):
             self.ongoing_backup[operation_id] = node
             node.protect_host()
 
-    def inc_hit_count(self, node: TreeNode):
-        if self.cache_controller.write_policy == "write_back":
+    def _inc_hit_count(self, node: TreeNode, chunked=False):
+        # skip the hit count update for chunked requests
+        if self.cache_controller.write_policy == "write_back" or chunked:
             return
         node.hit_count += 1
 
@@ -160,14 +171,6 @@ class HiRadixCache(RadixCache):
         if node.hit_count >= self.write_through_threshold:
             # write to host if the node is not backuped
             self.write_backup(node)
-        else:
-            if (
-                self.enable_storage
-                and (not node.backuped_storage)
-                and node.hit_count >= self.write_through_threshold_storage
-            ):
-                # if the node is backuped on host memory but not on storage
-                self.write_backup_storage(node)
 
     def writing_check(self, write_back=False):
         if write_back:
@@ -188,8 +191,11 @@ class HiRadixCache(RadixCache):
         )
         for _ in range(queue_size.item()):
             ack_id = self.cache_controller.ack_write_queue.get()
-            self.dec_lock_ref(self.ongoing_write_through[ack_id])
+            backuped_node = self.ongoing_write_through[ack_id]
+            self.dec_lock_ref(backuped_node)
             del self.ongoing_write_through[ack_id]
+            if self.enable_storage:
+                self.write_backup_storage(backuped_node)
 
     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
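With the separate storage hit-count threshold gone, the storage backup is now chained directly off the host write-through ack. A simplified, runnable mirror of that control flow (plain Python, not sglang code):

    import queue

    def drain_write_acks(ack_queue, ongoing, dec_lock_ref, write_backup_storage, storage_enabled):
        # Per acked node: drop the device lock, then immediately chain the
        # slower storage backup rather than waiting for a second hit threshold.
        while not ack_queue.empty():
            ack_id = ack_queue.get()
            node = ongoing.pop(ack_id)
            dec_lock_ref(node)
            if storage_enabled:
                write_backup_storage(node)

    acks = queue.Queue()
    ongoing = {7: "node-7"}
    acks.put(7)
    drain_write_acks(
        acks, ongoing,
        dec_lock_ref=lambda n: None,
        write_backup_storage=lambda n: print("backup", n),
        storage_enabled=True,
    )  # prints: backup node-7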
@@ -372,57 +378,54 @@ class HiRadixCache(RadixCache):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            self.check_revoked_prefetch()
-            self.check_backup_progress()
-
-    def check_revoked_prefetch(self):
-        queue_size = torch.tensor(
-            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
+            self.drain_storage_control_queues()
+
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
         )
         if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
             )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
-                last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass
 
-    def check_backup_progress(self):
-        queue_size = torch.tensor(
-            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
-            )
-        for _ in range(queue_size.item()):
-            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
-            host_node = self.ongoing_backup[ack_id]
-
-            if completed_tokens > 0:
-                if completed_tokens < len(host_node.key):
-                    # backup is only partially successful, split the node
-                    new_node = self._split_node(
-                        host_node.key, host_node, completed_tokens
-                    )
-                    new_node.backuped_storage = True
-                else:
-                    host_node.backuped_storage = True
-            host_node.release_host()
-            del self.ongoing_backup[ack_id]
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
+                last_host_node.release_host()
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do
+
+        # process backup acks
+        for _ in range(n_backup):
+            ack_id = cc.ack_backup_queue.get()
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)
 
     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
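The batched qsizes tensor turns three collectives into one: every TP rank all-reduces the queue sizes with MIN, so all ranks drain exactly the same number of items and apply identical tree updates. A standalone sketch of the pattern (assumes torch; falls back to local sizes when no process group is initialized):

    import torch
    import torch.distributed as dist

    def agreed_drain_counts(queues, group=None):
        # One MIN all-reduce over all queue sizes: every rank then drains
        # only the count that all ranks are guaranteed to have.
        qsizes = torch.tensor([q.qsize() for q in queues], dtype=torch.int)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(qsizes, op=dist.ReduceOp.MIN, group=group)
        return [int(n) for n in qsizes.tolist()]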
@@ -430,9 +433,12 @@ class HiRadixCache(RadixCache):
         if self.prefetch_stop_policy == "best_effort":
             return can_terminate
 
-        completed = (
-            operation.completed_tokens == len(operation.hash_value) * self.page_size
-        )
+        if len(operation.hash_value) == 0:
+            completed = False
+        else:
+            completed = (
+                operation.completed_tokens == len(operation.hash_value) * self.page_size
+            )
 
         if self.prefetch_stop_policy == "wait_complete":
             can_terminate = completed
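The new guard matters because an operation whose hash_value list is still empty would satisfy completed_tokens == 0 == len(hash_value) * page_size and look complete. Worked numbers (illustrative):

    page_size = 64
    for hashed_pages, completed_tokens in [(3, 192), (3, 128), (0, 0)]:
        done = hashed_pages > 0 and completed_tokens == hashed_pages * page_size
        print(hashed_pages, completed_tokens, done)
    # (3, 192) -> True, (3, 128) -> False,
    # (0, 0) -> False; before the guard, 0 == 0 * 64 made it look complete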
@@ -502,7 +508,7 @@ class HiRadixCache(RadixCache):
         self.cache_controller.mem_pool_host.update_prefetch(written_indices)
 
         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
@@ -536,6 +542,8 @@ class HiRadixCache(RadixCache):
         while last_node.evicted:
             host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
+        while not last_host_node.backuped:
+            last_host_node = last_host_node.parent
 
         return MatchResult(
             device_indices=value,
@@ -556,7 +564,11 @@ class HiRadixCache(RadixCache):
             len(new_input_tokens) % self.page_size
         )
         new_input_tokens = new_input_tokens[:prefetch_length]
-        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+        if (
+            not self.enable_storage
+            or prefetch_length < self.prefetch_threshold
+            or self.cache_controller.prefetch_rate_limited()
+        ):
             return
 
         last_host_node.protect_host()
@@ -564,6 +576,10 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+        if host_indices is None:
+            last_host_node.release_host()
+            # no sufficient host memory for prefetch
+            return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
@@ -663,11 +679,11 @@ class HiRadixCache(RadixCache):
             new_node.parent.children[self.get_child_key_fn(key)] = new_node
         return new_node
 
-    def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.monotonic()
+    def insert(self, key: List, value, chunked=False):
         if len(key) == 0:
             return 0
 
+        node = self.root_node
         child_key = self.get_child_key_fn(key)
         total_prefix_length = 0
 
@@ -684,7 +700,7 @@ class HiRadixCache(RadixCache):
                 self.token_to_kv_pool_host.update_synced(node.host_value)
                 self.evictable_size_ += len(node.value)
             else:
-                self.inc_hit_count(node)
+                self._inc_hit_count(node, chunked)
             total_prefix_length += prefix_len
         else:
             # partial match, split the node
@@ -694,7 +710,7 @@ class HiRadixCache(RadixCache):
                 self.token_to_kv_pool_host.update_synced(new_node.host_value)
                 self.evictable_size_ += len(new_node.value)
             else:
-                self.inc_hit_count(new_node)
+                self._inc_hit_count(new_node, chunked)
             total_prefix_length += prefix_len
             node = new_node
 
@@ -728,7 +744,7 @@ class HiRadixCache(RadixCache):
                 last_hash = new_node.hash_value[-1]
 
             if self.cache_controller.write_policy != "write_back":
-                self.inc_hit_count(new_node)
+                self._inc_hit_count(new_node, chunked)
         return total_prefix_length
 
     def _collect_leaves_device(self):
sglang/srt/mem_cache/lora_radix_cache.py

@@ -183,7 +183,7 @@ class LoRARadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
sglang/srt/mem_cache/memory_pool.py

@@ -36,12 +36,15 @@ import triton.language as tl
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
 
 
 class ReqToTokenPool:
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        import torch_npu
-
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
@@ -912,12 +913,24 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = torch.zeros(
+            self.k_buffer = torch.zeros(
                 (
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
-                    self.kv_lora_rank + self.qk_rope_head_dim,
+                    1,
+                    self.kv_lora_rank,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.v_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    1,
+                    self.qk_rope_head_dim,
                 ),
                 dtype=self.store_dtype,
                 device=self.device,
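The split layout keeps the MLA latent half and the rope half separately addressable instead of packing them into one buffer. Shape arithmetic with illustrative sizes (DeepSeek-style MLA commonly uses kv_lora_rank=512 and qk_rope_head_dim=64):

    layer_num, size, page_size = 2, 1024, 128
    kv_lora_rank, qk_rope_head_dim = 512, 64      # DeepSeek-style MLA sizes
    pages = size // page_size + 1                 # +1 padded page for dummy writes at slot 0
    k_shape = (layer_num, pages, page_size, 1, kv_lora_rank)      # latent half
    v_shape = (layer_num, pages, page_size, 1, qk_rope_head_dim)  # rope half
    assert kv_lora_rank + qk_rope_head_dim == 576  # same total width as the old packed kv_buffer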
@@ -931,12 +944,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             )
         self.mem_usage = kv_size / GB
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        kv_size_bytes = 0
+        for k_cache in self.k_buffer:
+            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        for v_cache in self.v_buffer:
+            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return kv_size_bytes
+
+    def get_kv_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
+
+    def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
-        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
+            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def set_kv_buffer(
@@ -949,18 +1002,28 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
 
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
 
-        import torch_npu
+        if cache_v is None:
+            cache_k, cache_v = cache_k.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
 
-        torch_npu._npu_reshape_and_cache_siso(
-            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
-            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
-                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        torch_npu.npu_scatter_nd_update_(
+            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
+            loc.view(-1, 1),
+            cache_k.view(-1, 1, self.kv_lora_rank),
+        )
+        torch_npu.npu_scatter_nd_update_(
+            self.v_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.qk_rope_head_dim
             ),
-            slot_indices=loc,
+            loc.view(-1, 1),
+            cache_v.view(-1, 1, self.qk_rope_head_dim),
         )
 
 
sglang/srt/mem_cache/memory_pool_host.py

@@ -7,7 +7,6 @@ from functools import wraps
 import psutil
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
 from sglang.srt.utils import is_npu
 
@@ -464,7 +463,7 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -488,8 +487,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
+            key_list.append(f"{key_}_{local_rank}_k")
+            key_list.append(f"{key_}_{local_rank}_v")
         element_size = (
             self.layer_num
             * self.dtype.itemsize
@@ -703,7 +702,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices):
+    def get_buffer_meta(self, keys, indices, local_rank):
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
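Passing local_rank in explicitly decouples the host pool from the TP process group; the per-rank key scheme itself is unchanged:

    key_, local_rank = "page_hash", 1
    print(f"{key_}_{local_rank}_k", f"{key_}_{local_rank}_v")
    # page_hash_1_k page_hash_1_v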
sglang/srt/mem_cache/radix_cache.py

@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False
 
         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
@@ -152,6 +151,7 @@ class RadixCache(BasePrefixCache):
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
+        self.root_node.host_value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
@@ -194,7 +194,7 @@ class RadixCache(BasePrefixCache):
             last_host_node=last_node,
         )
 
-    def insert(self, key: List, value=None):
+    def insert(self, key: List, value=None, chunked=False):
         if self.disable:
             return 0
 
@@ -239,7 +239,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -260,7 +260,9 @@ class RadixCache(BasePrefixCache):
         page_aligned_token_ids = token_ids[:page_aligned_len]
 
         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
+        new_prefix_len = self.insert(
+            page_aligned_token_ids, page_aligned_kv_indices, chunked=chunked
+        )
         self.token_to_kv_pool_allocator.free(
             kv_indices[len(req.prefix_indices) : new_prefix_len]
         )
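The chunked flag threads from cache_unfinished_req through insert down to HiRadixCache._inc_hit_count, so partially filled chunked requests do not inflate write-through hit counts. A plain-Python mirror of the guard (not sglang code):

    def inc_hit_count(node, write_policy, chunked=False):
        # Mirrors HiRadixCache._inc_hit_count: chunked inserts skip the bump.
        if write_policy == "write_back" or chunked:
            return
        node["hit_count"] += 1

    node = {"hit_count": 0}
    inc_hit_count(node, "write_through", chunked=True)
    print(node["hit_count"])  # still 0: the chunked pass left the counter alone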
sglang/srt/mem_cache/radix_cache_cpp.py

@@ -181,7 +181,7 @@ class RadixCacheCpp(BasePrefixCache):
         self.dec_lock_ref(req.last_node)
         self.req_to_token_pool.free(req.req_pool_idx)
 
-    def cache_unfinished_req(self, req: Req):
+    def cache_unfinished_req(self, req: Req, chunked=False):
         """Cache request when it is unfinished."""
         assert req.req_pool_idx is not None
         token_ids = req.fill_ids