sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
@@ -67,6 +67,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -510,6 +512,8 @@ class Scheduler(
             [
                 (TokenizedGenerateReqInput, self.handle_generate_request),
                 (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+                (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
+                (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
@@ -623,6 +627,8 @@ class Scheduler(
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
+                model_name=server_args.served_model_name,
+                storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -1018,14 +1024,26 @@ class Scheduler(
                 req
                 for req in recv_reqs
                 if isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
             control_reqs = [
                 req
                 for req in recv_reqs
                 if not isinstance(
-                    req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    req,
+                    (
+                        TokenizedGenerateReqInput,
+                        TokenizedEmbeddingReqInput,
+                        BatchTokenizedGenerateReqInput,
+                        BatchTokenizedEmbeddingReqInput,
+                    ),
                 )
             ]
         else:
@@ -1253,6 +1271,17 @@ class Scheduler(
         else:
             self._add_request_to_queue(req)
 
+    def handle_batch_generate_request(
+        self,
+        recv_req: BatchTokenizedGenerateReqInput,
+    ):
+        """Handle optimized batch generate request."""
+        logger.debug(f"Processing batch generate request with {len(recv_req)} requests")
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_generate_request(tokenized_req)
+
     def _add_request_to_queue(self, req: Req):
         req.queue_time_start = time.perf_counter()
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -1269,10 +1298,11 @@ class Scheduler(
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
             req.init_next_round_input(self.tree_cache)
-            last_hash = req.last_host_node.get_last_hash_value()
-            matched_len = len(req.prefix_indices) + req.host_hit_length
-            # todo, free-form fetching, calculating hash keys on the fly
-            if (matched_len > 0 and last_hash is not None) or matched_len == 0:
+            if req.last_node.backuped:
+                # only to initiate the prefetch if the last node is backuped
+                # otherwise, the allocated GPU memory must be locked for integrity
+                last_hash = req.last_host_node.get_last_hash_value()
+                matched_len = len(req.prefix_indices) + req.host_hit_length
                 new_input_tokens = req.fill_ids[matched_len:]
                 self.tree_cache.prefetch_from_storage(
                     req.rid, req.last_host_node, new_input_tokens, last_hash
@@ -1335,6 +1365,19 @@ class Scheduler(
             req.logprob_start_len = len(req.origin_input_ids) - 1
         self._add_request_to_queue(req)
 
+    def handle_batch_embedding_request(
+        self,
+        recv_req: BatchTokenizedEmbeddingReqInput,
+    ):
+        """Handle optimized batch embedding request."""
+        logger.debug(
+            f"Processing batch embedding request with {len(recv_req)} requests"
+        )
+
+        # Process each request in the batch
+        for tokenized_req in recv_req:
+            self.handle_embedding_request(tokenized_req)
+
     def self_check_during_idle(self):
         self.check_memory()
         self.check_tree_cache()
@@ -2513,7 +2556,15 @@ def is_health_check_generate_req(recv_req):
 
 
 def is_work_request(recv_req):
-    return isinstance(recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
+    return isinstance(
+        recv_req,
+        (
+            TokenizedGenerateReqInput,
+            TokenizedEmbeddingReqInput,
+            BatchTokenizedGenerateReqInput,
+            BatchTokenizedEmbeddingReqInput,
+        ),
+    )
 
 
 def run_scheduler_process(
sglang/srt/managers/scheduler_metrics_mixin.py
@@ -125,6 +125,14 @@ class SchedulerMetricsMixin:
                 total_queue_latency += req.queue_time_end - req.queue_time_start
             self.stats.avg_request_queue_latency = total_queue_latency / num_new_seq
 
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            self.stats.num_prefill_prealloc_queue_reqs = len(
+                self.disagg_prefill_bootstrap_queue.queue
+            )
+            self.stats.num_prefill_inflight_queue_reqs = len(
+                self.disagg_prefill_inflight_queue
+            )
+
         self.metrics_collector.log_stats(self.stats)
         self._emit_kv_metrics()
         self._publish_kv_events()
@@ -202,6 +210,13 @@ class SchedulerMetricsMixin:
         self.stats.spec_accept_length = spec_accept_length
         self.stats.total_retracted_reqs = self.total_retracted_reqs
         self.metrics_collector.log_stats(self.stats)
+        if self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.stats.num_decode_prealloc_queue_reqs = len(
+                self.disagg_decode_prealloc_queue.queue
+            )
+            self.stats.num_decode_transfer_queue_reqs = len(
+                self.disagg_decode_transfer_queue.queue
+            )
         self._emit_kv_metrics()
         self._publish_kv_events()
 
sglang/srt/managers/tokenizer_manager.py
@@ -71,6 +71,8 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
+    BatchTokenizedEmbeddingReqInput,
+    BatchTokenizedGenerateReqInput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -768,6 +770,30 @@ class TokenizerManager:
         self.rid_to_state[obj.rid] = state
         return state
 
+    def _send_batch_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        tokenized_objs: List[
+            Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]
+        ],
+        created_time: Optional[float] = None,
+    ):
+        """Send a batch of tokenized requests as a single batched request to the scheduler."""
+        if isinstance(tokenized_objs[0], TokenizedGenerateReqInput):
+            batch_req = BatchTokenizedGenerateReqInput(batch=tokenized_objs)
+        else:
+            batch_req = BatchTokenizedEmbeddingReqInput(batch=tokenized_objs)
+
+        self.send_to_scheduler.send_pyobj(batch_req)
+
+        # Create states for each individual request in the batch
+        for i, tokenized_obj in enumerate(tokenized_objs):
+            tmp_obj = obj[i]
+            state = ReqState(
+                [], False, asyncio.Event(), tmp_obj, created_time=created_time
+            )
+            self.rid_to_state[tmp_obj.rid] = state
+
     async def _wait_one_response(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -870,10 +896,17 @@ class TokenizerManager:
 
             tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
 
-            for i, tokenized_obj in enumerate(tokenized_objs):
+            # Send as a single batched request
+            self._send_batch_request(obj, tokenized_objs, created_time)
+
+            # Set up generators for each request in the batch
+            for i in range(batch_size):
                 tmp_obj = obj[i]
-                state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
-                generators.append(self._wait_one_response(tmp_obj, state, request))
+                generators.append(
+                    self._wait_one_response(
+                        tmp_obj, self.rid_to_state[tmp_obj.rid], request
+                    )
+                )
                 rids.append(tmp_obj.rid)
         else:
             # Sequential tokenization and processing
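Editor's sketch: the hunks above show the tokenizer manager constructing the new batch request types with a single batch keyword, and the scheduler calling len(recv_req) and iterating over them. Their real definitions live in sglang/srt/managers/io_struct.py (+30 lines, not shown in this excerpt); a minimal hypothetical stand-in that satisfies only the behavior visible above could look like this (the Sketch suffix marks it as illustrative, not the actual class):

from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class BatchTokenizedGenerateReqInputSketch:
    """Hypothetical stand-in; the real io_struct definition is not shown in this diff."""

    batch: List[object]  # TokenizedGenerateReqInput items in the real code

    def __len__(self) -> int:  # used by logger.debug(f"... {len(recv_req)} requests")
        return len(self.batch)

    def __iter__(self) -> Iterator[object]:  # used by `for tokenized_req in recv_req`
        return iter(self.batch)


# The scheduler-side handler then simply fans the batch back out:
reqs = BatchTokenizedGenerateReqInputSketch(batch=["req-0", "req-1"])
assert len(reqs) == 2
for tokenized_req in reqs:
    pass  # the real handler calls self.handle_generate_request(tokenized_req)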
sglang/srt/mem_cache/hicache_storage.py
@@ -2,6 +2,7 @@ import hashlib
 import logging
 import os
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Any, List, Optional
 
 import torch
@@ -9,17 +10,6 @@ import torch
 logger = logging.getLogger(__name__)
 
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    get_attention_tp_size,
-    is_dp_attention_enabled,
-)
-
-
 def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     hasher = hashlib.sha256()
 
@@ -32,6 +22,15 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
     return hasher.hexdigest()
 
 
+@dataclass
+class HiCacheStorageConfig:
+    tp_rank: int
+    tp_size: int
+    is_mla_model: bool
+    model_name: Optional[str]
+    extra_config: Optional[dict] = None
+
+
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
@@ -60,7 +59,7 @@ class HiCacheStorage(ABC):
         keys: List[str],
         target_locations: Optional[Any] = None,
         target_sizes: Optional[Any] = None,
-    ) -> List[torch.Tensor | None]:
+    ) -> List[torch.Tensor | None] | int:
         """
         Retrieve values for multiple keys.
         Returns a list of tensors or None for each key.
@@ -96,25 +95,37 @@ class HiCacheStorage(ABC):
         pass
 
     @abstractmethod
-    def exists(self, key: str) -> bool | dict:
+    def exists(self, key: str) -> bool:
         """
        Check if the key exists in the storage.
        Returns True if the key exists, False otherwise.
        """
         pass
 
+    def batch_exists(self, keys: List[str]) -> int:
+        """
+        Check if the keys exist in the storage.
+        return the number of consecutive existing keys from the start.
+        Can be overridden by subclasses for more efficient implementation.
+        """
+        for i in range(len(keys)):
+            if not self.exists(keys[i]):
+                return i
+        return len(keys)
+
 
 class HiCacheFile(HiCacheStorage):
 
-    def __init__(self, file_path: str = "/tmp/hicache", is_mla: bool = False):
+    def __init__(
+        self, storage_config: HiCacheStorageConfig, file_path: str = "/tmp/hicache"
+    ):
         self.file_path = os.getenv("SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR", file_path)
-        if is_dp_attention_enabled():
-            tp_rank = get_attention_tp_rank()
-            tp_size = get_attention_tp_size()
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            tp_size = get_tensor_model_parallel_world_size()
 
+        tp_rank, tp_size, is_mla = (
+            storage_config.tp_rank,
+            storage_config.tp_size,
+            storage_config.is_mla_model,
+        )
         self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 and not is_mla else ""
         if not os.path.exists(self.file_path) and tp_rank == 0:
            os.makedirs(self.file_path)
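A minimal usage sketch of the refactored file backend shown above, assuming the 0.5.1.post3 wheel is installed; the model name and file path below are placeholders, and batch_exists() is the new default on HiCacheStorage, which returns how many keys at the start of the list already exist:

from sglang.srt.mem_cache.hicache_storage import HiCacheFile, HiCacheStorageConfig

# Single-rank, non-MLA configuration; the values here are illustrative only.
config = HiCacheStorageConfig(
    tp_rank=0,
    tp_size=1,
    is_mla_model=False,
    model_name="example-model",
    extra_config=None,
)
backend = HiCacheFile(config, file_path="/tmp/hicache")

# 0 means no leading keys are present; N means the first N keys all exist.
prefix_hits = backend.batch_exists(["page-0", "page-1", "page-2"])
print(prefix_hits)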
sglang/srt/mem_cache/hiradix_cache.py
@@ -39,6 +39,8 @@ class HiRadixCache(RadixCache):
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
         hicache_storage_prefetch_policy: Optional[str] = "best_effort",
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
 
         if hicache_io_backend == "direct":
@@ -87,6 +89,8 @@ class HiRadixCache(RadixCache):
             io_backend=hicache_io_backend,
             storage_backend=hicache_storage_backend,
             prefetch_threshold=self.prefetch_threshold,
+            model_name=model_name,
+            storage_backend_extra_config=storage_backend_extra_config,
         )
 
         # record the nodes with ongoing write through
@@ -430,9 +434,12 @@ class HiRadixCache(RadixCache):
         if self.prefetch_stop_policy == "best_effort":
             return can_terminate
 
-        completed = (
-            operation.completed_tokens == len(operation.hash_value) * self.page_size
-        )
+        if len(operation.hash_value) == 0:
+            completed = False
+        else:
+            completed = (
+                operation.completed_tokens == len(operation.hash_value) * self.page_size
+            )
 
         if self.prefetch_stop_policy == "wait_complete":
             can_terminate = completed
@@ -536,6 +543,8 @@ class HiRadixCache(RadixCache):
         while last_node.evicted:
             host_hit_length += len(last_node.host_value)
             last_node = last_node.parent
+        while not last_host_node.backuped:
+            last_host_node = last_host_node.parent
 
         return MatchResult(
             device_indices=value,
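For context on the @@ -430,9 +434,12 @@ hunk above: completion is judged in whole pages, so with, say, a page size of 64 and three hash values the prefetch only counts as complete once completed_tokens reaches 3 * 64 = 192, and the new empty-list guard avoids the degenerate 0 == 0 case that the old expression would have treated as complete. A small arithmetic sketch with made-up numbers:

# Illustrative values only; page_size and the number of hash keys depend on the deployment.
page_size = 64
hash_value = ["h0", "h1", "h2"]   # three pages worth of keys
completed_tokens = 192

if len(hash_value) == 0:
    completed = False             # new guard: an empty plan is never "complete"
else:
    completed = completed_tokens == len(hash_value) * page_size

print(completed)  # True, since 192 == 3 * 64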
sglang/srt/mem_cache/memory_pool.py
@@ -36,12 +36,15 @@ import triton.language as tl
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
 
 
 class ReqToTokenPool:
@@ -624,8 +627,6 @@ class AscendTokenToKVPool(MHATokenToKVPool):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        import torch_npu
-
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
@@ -912,12 +913,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer = torch.zeros(
+            self.k_buffer = torch.zeros(
                 (
                     layer_num,
                     self.size // self.page_size + 1,
                     self.page_size,
-                    self.kv_lora_rank + self.qk_rope_head_dim,
+                    self.kv_lora_rank,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.v_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.qk_rope_head_dim,
                 ),
                 dtype=self.store_dtype,
                 device=self.device,
@@ -931,12 +942,52 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
             )
         self.mem_usage = kv_size / GB
 
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "k_buffer")
+        assert hasattr(self, "v_buffer")
+        kv_size_bytes = 0
+        for k_cache in self.k_buffer:
+            kv_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize
+        for v_cache in self.v_buffer:
+            kv_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
+        return kv_size_bytes
+
+    def get_kv_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return (
+            self.k_buffer[layer_id - self.start_layer],
+            self.v_buffer[layer_id - self.start_layer],
+        )
+
+    def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.k_buffer[layer_id - self.start_layer]
+
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+
+        if self.store_dtype != self.dtype:
+            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
+        return self.v_buffer[layer_id - self.start_layer]
+
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
-        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
-        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
+            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
+        ]
+        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i].nbytes for i in range(self.layer_num)
+        ]
+        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
+            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def set_kv_buffer(
@@ -953,14 +1004,22 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         if self.store_dtype != self.dtype:
             cache_k = cache_k.view(self.store_dtype)
 
-        import torch_npu
+        if cache_v is None:
+            cache_k, cache_v = cache_k.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
 
-        torch_npu._npu_reshape_and_cache_siso(
-            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
-            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
-                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+        torch_npu.npu_scatter_nd_update_(
+            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
+            loc.view(-1, 1),
+            cache_k.view(-1, 1, self.kv_lora_rank),
+        )
+        torch_npu.npu_scatter_nd_update_(
+            self.v_buffer[layer_id - self.start_layer].view(
+                -1, 1, self.qk_rope_head_dim
            ),
-            slot_indices=loc,
+            loc.view(-1, 1),
+            cache_v.view(-1, 1, self.qk_rope_head_dim),
         )
 
 
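The memory_pool.py hunks above replace the single fused MLA kv_buffer with a k_buffer of width kv_lora_rank and a v_buffer of width qk_rope_head_dim, and set_kv_buffer now splits a fused cache_k along its last dimension when no cache_v is provided. A shape-only sketch of that split (the sizes are illustrative, not tied to any particular model):

import torch

# Illustrative dimensions; real values come from the model config.
kv_lora_rank, qk_rope_head_dim = 512, 64
num_tokens = 8

fused = torch.zeros(num_tokens, 1, kv_lora_rank + qk_rope_head_dim)
nope_part, rope_part = fused.split([kv_lora_rank, qk_rope_head_dim], dim=-1)

print(nope_part.shape)  # torch.Size([8, 1, 512]) -> written into k_buffer
print(rope_part.shape)  # torch.Size([8, 1, 64])  -> written into v_buffer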
sglang/srt/mem_cache/memory_pool_host.py
@@ -465,6 +465,7 @@ class MHATokenToKVPoolHost(HostKVCache):
             raise ValueError(f"Unsupported layout: {self.layout}")
 
     def get_buffer_meta(self, keys, indices):
+        local_rank = get_tensor_model_parallel_rank()
         ptr_list = []
         key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
@@ -488,8 +489,8 @@ class MHATokenToKVPoolHost(HostKVCache):
             ptr_list.append(k_ptr)
             ptr_list.append(v_ptr)
             key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_k")
-            key_list.append(f"{key_}_{get_tensor_model_parallel_rank()}_v")
+            key_list.append(f"{key_}_{local_rank}_k")
+            key_list.append(f"{key_}_{local_rank}_v")
             element_size = (
                 self.layer_num
                 * self.dtype.itemsize
sglang/srt/mem_cache/radix_cache.py
@@ -152,6 +152,7 @@ class RadixCache(BasePrefixCache):
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
+        self.root_node.host_value = []
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -11,12 +11,7 @@ from typing import Any, List, Optional, Tuple
 
 import torch
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.layers.dp_attention import (
-    get_attention_tp_rank,
-    is_dp_attention_enabled,
-)
-from sglang.srt.mem_cache.hicache_storage import HiCacheStorage
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
 from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 
 logger = logging.getLogger(__name__)
@@ -172,19 +167,16 @@ class HiCacheHF3FS(HiCacheStorage):
 
     @staticmethod
     def from_env_config(
-        bytes_per_page: int, dtype: torch.dtype, rank: int = None
+        bytes_per_page: int,
+        dtype: torch.dtype,
+        storage_config: HiCacheStorageConfig = None,
     ) -> "HiCacheHF3FS":
         from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
             Hf3fsGlobalMetadataClient,
             Hf3fsLocalMetadataClient,
         )
 
-        if rank is None:
-            rank = (
-                get_attention_tp_rank()
-                if is_dp_attention_enabled()
-                else get_tensor_model_parallel_rank()
-            )
+        rank = storage_config.tp_rank if storage_config is not None else 0
 
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path: