sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (35)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/launch_lb.py +5 -20
  4. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  5. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  6. sglang/srt/layers/attention/utils.py +6 -1
  7. sglang/srt/layers/moe/ep_moe/layer.py +19 -34
  8. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
  9. sglang/srt/layers/quantization/fp8.py +52 -0
  10. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  11. sglang/srt/managers/cache_controller.py +35 -35
  12. sglang/srt/managers/scheduler.py +1 -0
  13. sglang/srt/mem_cache/hicache_storage.py +15 -6
  14. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  15. sglang/srt/mem_cache/memory_pool.py +15 -118
  16. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  17. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  18. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  19. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  20. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  21. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  22. sglang/srt/model_executor/model_runner.py +8 -1
  23. sglang/srt/model_loader/weight_utils.py +2 -0
  24. sglang/srt/models/deepseek_v2.py +5 -6
  25. sglang/srt/models/glm4_moe.py +3 -3
  26. sglang/srt/models/step3_vl.py +0 -3
  27. sglang/srt/server_args.py +40 -6
  28. sglang/srt/utils.py +1 -0
  29. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  30. sglang/version.py +1 -1
  31. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
  32. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
  33. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  34. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  35. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -1,10 +1,13 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple

 import torch
+from packaging import version as pkg_version

 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
@@ -33,6 +36,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)


+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
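The new helper is cached with `lru_cache(maxsize=1)`, so the server-args lookup and the flashinfer import probe happen once per process. The threshold is a pre-release version, which `packaging` orders just before the final 0.2.9. A minimal sketch of that comparison (illustrative only, not part of the diff):

```python
# Sketch of the version comparison the gate above relies on.
from packaging import version as pkg_version

threshold = pkg_version.parse("0.2.9rc1")
print(pkg_version.parse("0.2.8") >= threshold)     # False -> TRT-LLM MoE path stays off
print(pkg_version.parse("0.2.9rc1") >= threshold)  # True
print(pkg_version.parse("0.2.9") >= threshold)     # True (final release sorts after rc1)
print(pkg_version.parse("0.3.0") >= threshold)     # True
```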
@@ -119,7 +131,8 @@ class FusedMoE(torch.nn.Module):
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-            self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
+            if not self.enable_flashinfer_cutlass_moe:
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")

         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -454,7 +467,7 @@ class FusedMoE(torch.nn.Module):
         )

         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if getattr(self, "use_flashinfer_trtllm_moe", False):
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]

         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -686,3 +699,44 @@ class FusedMoE(torch.nn.Module):
             for expert_id in range(num_experts)
             for shard_id in ["w1", "w2", "w3"]
         ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
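The wrapper pops its routing-specific keyword arguments before delegating to `FusedMoE.__init__`, so the base constructor signature is untouched. A minimal sketch of that pop-then-super pattern (hypothetical class names, not sglang code):

```python
# Illustrative only: subclass-specific kwargs are removed before the base __init__ sees them.
class BaseMoE:
    def __init__(self, num_experts: int, top_k: int):
        self.num_experts = num_experts
        self.top_k = top_k


class RoutedMoE(BaseMoE):
    def __init__(self, *args, **kwargs):
        # Extract subclass-only options first...
        self.renormalize = kwargs.pop("renormalize", True)
        self.topk_group = kwargs.pop("topk_group", None)
        # ...then forward only what the base class actually expects.
        super().__init__(*args, **kwargs)


moe = RoutedMoE(num_experts=8, top_k=2, renormalize=False, topk_group=1)
print(moe.num_experts, moe.renormalize)  # 8 False
```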
sglang/srt/layers/quantization/fp8.py
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )


+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
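A quick worked example of the tile-size heuristic, as a standalone sketch (the local `next_power_of_2` is a stand-in mirroring `sglang.srt.utils.next_power_of_2`; the numbers are illustrative):

```python
# Standalone sketch of the tile sizing heuristic above.
def next_power_of_2(n: int) -> int:
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def get_tile_tokens_dim(num_tokens, top_k, num_experts):
    num_tokens_per_expert = (num_tokens * top_k) // num_experts  # assume even routing
    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)     # round up to a power of 2
    return min(max(tile_tokens_dim, 8), 64)                      # clamp to the kernel's 8-64 range


# e.g. DeepSeek-style routing with 256 experts and top_k=8:
print(get_tile_tokens_dim(16, 8, 256))    # 0 tokens/expert  -> clamped up to 8
print(get_tile_tokens_dim(1024, 8, 256))  # 32 tokens/expert -> 32
print(get_tile_tokens_dim(8192, 8, 256))  # 256 tokens/expert -> capped at 64
```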
@@ -1076,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             routed_scaling_factor=routed_scaling_factor,
         )

+    def apply_with_router_logits(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert (
+            activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=layer.correction_bias.to(x.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=layer.w13_weight,
+            gemm1_weights_scale=layer.w13_weight_scale_inv,
+            gemm2_weights=layer.w2_weight,
+            gemm2_weights_scale=layer.w2_weight_scale_inv,
+            num_experts=layer.num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.w2_weight.shape[2],
+            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+            local_num_experts=layer.num_local_experts,
+            routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=get_tile_tokens_dim(
+                x.shape[0], layer.top_k, layer.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
     def maybe_apply_hip_fused_experts(
         self,
         layer: torch.nn.Module,
sglang/srt/layers/quantization/w8a8_int8.py
@@ -231,7 +231,10 @@ class W8A8Int8Config(QuantizationConfig):

     @classmethod
     def get_config_filenames(cls) -> List[str]:
-        return []
+        filenames = []
+        if _is_npu:
+            filenames.append("quant_model_description.json")
+        return filenames

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
sglang/srt/managers/cache_controller.py
@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-    MooncakeStore,
-    get_hash_str_mooncake,
-)
-from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

 logger = logging.getLogger(__name__)

@@ -237,40 +231,35 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
-        # using kernel for small page KV cache transfer and DMA for large pages
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend

         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-            # create a new communication group for synchronizing storage operations across TP workers
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.get_hash_str = get_hash_str
+            elif storage_backend == "nixl":
+                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+
+                self.storage_backend = HiCacheNixl()
+                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
                 from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )

                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
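The controller no longer guesses the IO backend from the page size; the caller now passes `io_backend` in directly, and each storage backend is imported lazily only when its name is selected. For reference, a sketch of the heuristic that used to live here (hypothetical helper name, assuming the same 64-token threshold from the removed lines):

```python
from typing import Optional


def pick_io_backend(page_size: int, io_backend: Optional[str] = None) -> str:
    # Small pages go through the gather/scatter kernel path; large pages use direct DMA copies.
    if io_backend:
        return io_backend
    IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
    return "direct" if page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD else "kernel"


print(pick_io_backend(32), pick_io_backend(64))  # kernel direct
```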
@@ -288,6 +277,16 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -439,11 +438,8 @@ class HiCacheController:
                 host_indices, device_indices = self.move_indices(
                     operation.host_indices, operation.device_indices
                 )
-                self.mem_pool_device.backup_to_host_all_layer(
-                    self.mem_pool_host,
-                    host_indices,
-                    device_indices,
-                    self.io_backend,
+                self.mem_pool_host.backup_from_device_all_layer(
+                    self.mem_pool_device, host_indices, device_indices, self.io_backend
                 )
                 self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
@@ -483,8 +479,8 @@ class HiCacheController:
                 batch_operation.host_indices, batch_operation.device_indices
             )
             for i in range(self.mem_pool_host.layer_num):
-                self.mem_pool_device.load_from_host_per_layer(
-                    self.mem_pool_host,
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
                     host_indices,
                     device_indices,
                     i,
@@ -545,7 +541,11 @@ class HiCacheController:
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
-            page_data = self.storage_backend.batch_get(page_hashes)
+            # todo: zero copy
+            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+                page_hashes
+            )
+            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
                     f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
@@ -679,7 +679,7 @@ class HiCacheController:
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             page_data = [
-                self.mem_pool_host.get_flat_data_pages(
+                self.mem_pool_host.get_flat_data_page(
                     operation.host_indices[j * self.page_size]
                 )
                 for j in range(i, i + len(page_hashes))
sglang/srt/managers/scheduler.py
@@ -588,6 +588,7 @@ class Scheduler(
                     == "fa3"  # hot fix for incompatibility
                     else server_args.hicache_io_backend
                 ),
+                hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
sglang/srt/mem_cache/hicache_storage.py
@@ -123,13 +123,22 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-            # todo: fixing the target_location logic to enable in-place loading
-            loaded_tensor = torch.load(tensor_path)
-            if isinstance(loaded_tensor, torch.Tensor):
-                return loaded_tensor
+            if target_location is not None:
+                # Load directly into target_location's memory buffer
+                with open(tensor_path, "rb") as f:
+                    target_location.set_(
+                        torch.frombuffer(f.read(), dtype=target_location.dtype)
+                        .reshape(target_location.shape)
+                        .storage()
+                    )
+                return target_location
             else:
-                logger.error(f"Loaded data for key {key} is not a tensor.")
-                return None
+                loaded_tensor = torch.load(tensor_path)
+                if isinstance(loaded_tensor, torch.Tensor):
+                    return loaded_tensor
+                else:
+                    logger.error(f"Loaded data for key {key} is not a tensor.")
+                    return None
         except FileNotFoundError:
             return None

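When a `target_location` tensor is supplied, the file's raw bytes are wrapped with `torch.frombuffer` and swapped into the target's storage, so no extra tensor is allocated on the load path. A minimal standalone sketch of in-place loading into a preallocated buffer (illustrative only; it uses `copy_` into the target rather than swapping its storage with `set_` as the diff does):

```python
import torch

src = torch.arange(6, dtype=torch.float32).reshape(2, 3)
with open("/tmp/page.bin", "wb") as f:
    f.write(src.numpy().tobytes())  # persist raw bytes, like the .bin files above

target = torch.empty(2, 3, dtype=torch.float32)  # preallocated destination buffer
with open("/tmp/page.bin", "rb") as f:
    loaded = torch.frombuffer(bytearray(f.read()), dtype=target.dtype).reshape(target.shape)
target.copy_(loaded)  # fill the existing buffer; nothing new is handed back to the caller
print(torch.equal(src, target))  # True
```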
sglang/srt/mem_cache/hiradix_cache.py
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
     ):
+
+        if hicache_io_backend == "direct":
+            if hicache_mem_layout == "page_first":
+                hicache_mem_layout = "layer_first"
+                logger.warning(
+                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
+                )
+
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache, hicache_ratio, hicache_size, page_size
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -436,7 +453,7 @@ class HiRadixCache(RadixCache):
             last_host_node,
             fetched_token_ids,
             written_indices,
-            hash_value[:min_completed_tokens],
+            hash_value[: min_completed_tokens // self.page_size],
         )
         if len(written_indices):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
@@ -529,7 +546,7 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(node.key, key)
             key = key[prefix_len:]
             host_value = host_value[prefix_len:]
-            hash_value = hash_value[prefix_len:]
+            hash_value = hash_value[prefix_len // self.page_size :]
             matched_length += prefix_len

             if prefix_len < len(node.key):
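Both call sites now index `hash_value` in page units rather than token units: one hash entry covers `page_size` tokens, so token counts and offsets must be divided by the page size before slicing. A tiny worked example with illustrative numbers:

```python
# One hash per page: convert token counts/offsets to page counts before indexing.
page_size = 64
hash_value = [f"h{i}" for i in range(8)]       # hashes for 8 pages = 512 tokens

min_completed_tokens = 256
print(hash_value[: min_completed_tokens // page_size])  # first 4 page hashes

prefix_len = 128                               # tokens already matched in the radix tree
print(hash_value[prefix_len // page_size :])   # drop the 2 pages covered by the prefix
```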
sglang/srt/mem_cache/memory_pool.py
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


 class ReqToTokenPool:
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter

@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
                 )
                 for _ in range(self.layer_num)
             ]
-        self.token_stride = self.head_num * self.head_dim
-        self.data_ptrs = torch.tensor(
-            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_buffer],
             dtype=torch.uint64,
             device=self.device,
         )
+        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
         self.data_strides = torch.tensor(
             [
                 np.prod(x.shape[1:]) * x.dtype.itemsize
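The pool now keeps separate pointer tables for the K and V buffers and concatenates them for code that still wants the combined view; these tables of raw buffer addresses are what per-layer copy kernels (such as the `copy_all_layer_kv_cache` Triton kernel at the end of this file) index by layer. A small sketch of building such a table with `data_ptr()` (toy sizes, CPU tensors; like the diff it assumes a PyTorch recent enough to support `torch.uint64`):

```python
import torch

layer_num, tokens, head_num, head_dim = 2, 4, 1, 8
k_buffer = [torch.zeros(tokens, head_num, head_dim) for _ in range(layer_num)]
v_buffer = [torch.zeros(tokens, head_num, head_dim) for _ in range(layer_num)]

# data_ptr() returns the address of each layer's buffer, stored as uint64 so a
# copy kernel can pick the right base pointer by layer index.
k_data_ptrs = torch.tensor([x.data_ptr() for x in k_buffer], dtype=torch.uint64)
v_data_ptrs = torch.tensor([x.data_ptr() for x in v_buffer], dtype=torch.uint64)
data_ptrs = torch.cat([k_data_ptrs, v_data_ptrs], dim=0)  # combined K-then-V table
print(data_ptrs.shape)  # torch.Size([4])
```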
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
                 layer_id_override=layer_id_pool,
             )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-


 class AscendTokenToKVPool(MHATokenToKVPool):
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]

-        self.token_stride = kv_lora_rank + qk_rope_head_dim
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-

 @triton.jit
 def copy_all_layer_kv_cache(