sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/ep_moe/layer.py +19 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
- sglang/srt/layers/quantization/fp8.py +52 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +35 -35
- sglang/srt/managers/scheduler.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +15 -6
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +8 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/step3_vl.py +0 -3
- sglang/srt/server_args.py +40 -6
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py CHANGED

@@ -1,10 +1,13 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
+from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
@@ -33,6 +36,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)
 
 
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -119,7 +131,8 @@ class FusedMoE(torch.nn.Module):
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-
+            if not self.enable_flashinfer_cutlass_moe:
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
 
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -454,7 +467,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -686,3 +699,44 @@ class FusedMoE(torch.nn.Module):
             for expert_id in range(num_experts)
             for shard_id in ["w1", "w2", "w3"]
         ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
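The new `should_use_flashinfer_trtllm_moe()` helper is evaluated once (`lru_cache(maxsize=1)`) and gates the TRT-LLM MoE path on both the `enable_flashinfer_trtllm_moe` server flag and the installed flashinfer version; when flashinfer is absent the check still passes, so the failure surfaces at the actual import site rather than here. A minimal standalone sketch of that gating pattern (the free function and its arguments below are illustrative, not part of the package):

```python
import importlib.util

from packaging import version as pkg_version


def trtllm_moe_available(enabled: bool, min_version: str = "0.2.9rc1") -> bool:
    """Standalone mirror of the gate used by should_use_flashinfer_trtllm_moe()."""
    if not enabled:
        return False
    # If flashinfer is not installed, the gate still passes; the ImportError
    # then surfaces at the call site that actually needs flashinfer.
    if importlib.util.find_spec("flashinfer") is None:
        return True
    import flashinfer

    return pkg_version.parse(flashinfer.__version__) >= pkg_version.parse(min_version)


print(trtllm_moe_available(False))  # always False
print(trtllm_moe_available(True))   # depends on the local flashinfer install
```

Because the real helper caches its result, the flag and version are only checked once per process.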
sglang/srt/layers/quantization/fp8.py CHANGED

@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
@@ -1076,6 +1087,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             routed_scaling_factor=routed_scaling_factor,
         )
 
+    def apply_with_router_logits(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert (
+            activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=layer.correction_bias.to(x.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=layer.w13_weight,
+            gemm1_weights_scale=layer.w13_weight_scale_inv,
+            gemm2_weights=layer.w2_weight,
+            gemm2_weights_scale=layer.w2_weight_scale_inv,
+            num_experts=layer.num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.w2_weight.shape[2],
+            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+            local_num_experts=layer.num_local_experts,
+            routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=get_tile_tokens_dim(
+                x.shape[0], layer.top_k, layer.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
     def maybe_apply_hip_fused_experts(
         self,
         layer: torch.nn.Module,
sglang/srt/layers/quantization/w8a8_int8.py CHANGED

@@ -231,7 +231,10 @@ class W8A8Int8Config(QuantizationConfig):
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
-
+        filenames = []
+        if _is_npu:
+            filenames.append("quant_model_description.json")
+        return filenames
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
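The reworked `get_config_filenames` only requests Ascend's `quant_model_description.json` on NPU builds and otherwise returns an empty list. A minimal sketch of the same branch, with `_is_npu` hardcoded for illustration:

```python
from typing import List

_is_npu = False  # module-level flag in w8a8_int8.py; hardcoded here for the sketch


def get_config_filenames() -> List[str]:
    filenames: List[str] = []
    if _is_npu:
        filenames.append("quant_model_description.json")
    return filenames


assert get_config_filenames() == []  # empty on non-NPU builds
```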
sglang/srt/managers/cache_controller.py CHANGED

@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-    MooncakeStore,
-    get_hash_str_mooncake,
-)
-from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS
 
 logger = logging.getLogger(__name__)
 
@@ -237,40 +231,35 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
-
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend
 
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-
-            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.get_hash_str = get_hash_str
+            elif storage_backend == "nixl":
+                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+
+                self.storage_backend = HiCacheNixl()
+                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
                 from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )
 
                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
@@ -288,6 +277,16 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -439,11 +438,8 @@ class HiCacheController:
             host_indices, device_indices = self.move_indices(
                 operation.host_indices, operation.device_indices
            )
-            self.
-            self.
-                host_indices,
-                device_indices,
-                self.io_backend,
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
             )
             self.write_stream.synchronize()
             self.mem_pool_host.complete_io(operation.host_indices)
@@ -483,8 +479,8 @@ class HiCacheController:
                 batch_operation.host_indices, batch_operation.device_indices
             )
             for i in range(self.mem_pool_host.layer_num):
-                self.
-                self.
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
                     host_indices,
                     device_indices,
                     i,
@@ -545,7 +541,11 @@ class HiCacheController:
     def generic_page_transfer(self, operation, batch_size=8):
        for i in range(0, len(operation.hash_value), batch_size):
            page_hashes = operation.hash_value[i : i + batch_size]
-
+            # todo: zero copy
+            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+                page_hashes
+            )
+            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
                     f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
@@ -679,7 +679,7 @@ class HiCacheController:
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             page_data = [
-                self.mem_pool_host.
+                self.mem_pool_host.get_flat_data_page(
                     operation.host_indices[j * self.page_size]
                 )
                 for j in range(i, i + len(page_hashes))
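The storage-backend imports move from module scope into the `storage_backend is not None` branch, so optional dependencies (nixl, mooncake, hf3fs) are only imported when that backend is actually selected. A minimal sketch of the same lazy-import dispatch; the registry dict and factory function here are illustrative, not part of the package:

```python
import importlib
from typing import Any

# Module paths mirror the diff; the dict itself is only for this sketch.
_STORAGE_BACKENDS = {
    "file": ("sglang.srt.mem_cache.hicache_storage", "HiCacheFile"),
    "nixl": ("sglang.srt.mem_cache.nixl.hicache_nixl", "HiCacheNixl"),
    "mooncake": ("sglang.srt.mem_cache.mooncake_store.mooncake_store", "MooncakeStore"),
}


def create_storage_backend(name: str) -> Any:
    # The module is imported only when this backend is requested, so its
    # optional dependencies are not pulled in for every server start.
    module_path, class_name = _STORAGE_BACKENDS[name]
    module = importlib.import_module(module_path)
    return getattr(module, class_name)()
```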
sglang/srt/managers/scheduler.py CHANGED

@@ -588,6 +588,7 @@ class Scheduler(
                 == "fa3"  # hot fix for incompatibility
                 else server_args.hicache_io_backend
             ),
+            hicache_mem_layout=server_args.hicache_mem_layout,
             hicache_storage_backend=server_args.hicache_storage_backend,
         )
         self.tp_worker.register_hicache_layer_transfer_counter(
sglang/srt/mem_cache/hicache_storage.py CHANGED

@@ -123,13 +123,22 @@ class HiCacheFile(HiCacheStorage):
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-
-
-
-
+            if target_location is not None:
+                # Load directly into target_location's memory buffer
+                with open(tensor_path, "rb") as f:
+                    target_location.set_(
+                        torch.frombuffer(f.read(), dtype=target_location.dtype)
+                        .reshape(target_location.shape)
+                        .storage()
+                    )
+                return target_location
             else:
-
-
+                loaded_tensor = torch.load(tensor_path)
+                if isinstance(loaded_tensor, torch.Tensor):
+                    return loaded_tensor
+                else:
+                    logger.error(f"Loaded data for key {key} is not a tensor.")
+                    return None
         except FileNotFoundError:
             return None
 
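`HiCacheFile.get` now distinguishes two paths: with a preallocated `target_location` it reads the raw bytes straight into that buffer, otherwise it falls back to `torch.load` and validates the result. A rough, self-contained round-trip of the in-place path, using `Tensor.copy_` instead of the diff's `set_()`; the file path and shapes are made up:

```python
import torch

src = torch.arange(12, dtype=torch.float32).reshape(3, 4)
path = "/tmp/demo_page.bin"  # hypothetical path
with open(path, "wb") as f:
    f.write(src.numpy().tobytes())

target = torch.empty(3, 4, dtype=torch.float32)  # stands in for a preallocated host page
with open(path, "rb") as f:
    raw = bytearray(f.read())  # writable buffer keeps torch.frombuffer warning-free
target.copy_(torch.frombuffer(raw, dtype=target.dtype).reshape(target.shape))

assert torch.equal(src, target)
```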
sglang/srt/mem_cache/hiradix_cache.py CHANGED

@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
         hicache_size: int,
         hicache_write_policy: str,
         hicache_io_backend: str,
+        hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
     ):
+
+        if hicache_io_backend == "direct":
+            if hicache_mem_layout == "page_first":
+                hicache_mem_layout = "layer_first"
+                logger.warning(
+                    "Page first layout is not supported with direct IO backend, switching to layer first layout"
+                )
+
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
             self.token_to_kv_pool_host = MHATokenToKVPoolHost(
-                self.kv_cache,
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         elif isinstance(self.kv_cache, MLATokenToKVPool):
             self.token_to_kv_pool_host = MLATokenToKVPoolHost(
-                self.kv_cache,
+                self.kv_cache,
+                hicache_ratio,
+                hicache_size,
+                page_size,
+                hicache_mem_layout,
             )
         else:
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -436,7 +453,7 @@ class HiRadixCache(RadixCache):
             last_host_node,
             fetched_token_ids,
             written_indices,
-            hash_value[:min_completed_tokens],
+            hash_value[: min_completed_tokens // self.page_size],
         )
         if len(written_indices):
             self.cache_controller.mem_pool_host.update_prefetch(written_indices)
@@ -529,7 +546,7 @@ class HiRadixCache(RadixCache):
             prefix_len = self.key_match_fn(node.key, key)
             key = key[prefix_len:]
             host_value = host_value[prefix_len:]
-            hash_value = hash_value[prefix_len:]
+            hash_value = hash_value[prefix_len // self.page_size :]
             matched_length += prefix_len
 
             if prefix_len < len(node.key):
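Both slicing fixes account for `hash_value` holding one hash per page while `prefix_len` and `min_completed_tokens` are token counts, so the index has to be divided by `page_size`. A tiny illustration with made-up numbers:

```python
page_size = 64
prefix_len = 192                              # tokens already matched on this node
pages_matched = prefix_len // page_size       # 3 pages
hash_value = ["h0", "h1", "h2", "h3", "h4"]   # one hash per page

# Slice by pages, not tokens: hash_value[192:] would incorrectly drop everything.
assert hash_value[pages_matched:] == ["h3", "h4"]
```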
sglang/srt/mem_cache/memory_pool.py CHANGED

@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import triton
 import triton.language as tl
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import get_bool_env_var, is_cuda,
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
-_is_npu = is_npu()
-if not _is_npu:
-    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
 
 
 class ReqToTokenPool:
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError()
-
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
-
-        self.
-            [x.data_ptr() for x in self.k_buffer
+
+        self.k_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.v_data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.v_buffer],
             dtype=torch.uint64,
             device=self.device,
         )
+        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
         self.data_strides = torch.tensor(
             [
                 np.prod(x.shape[1:]) * x.dtype.itemsize
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()
 
-    def load_from_host_per_layer(
-        self,
-        host_pool,
-        host_indices,
-        device_indices,
-        layer_id,
-        io_backend,
-    ):
-        transfer_kv_per_layer(
-            src_k=host_pool.k_buffer[layer_id],
-            dst_k=self.k_buffer[layer_id],
-            src_v=host_pool.v_buffer[layer_id],
-            dst_v=self.v_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.k_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer(
-                src_k=self.k_buffer[layer_id],
-                dst_k=host_pool.k_buffer[layer_id],
-                src_v=self.v_buffer[layer_id],
-                dst_v=host_pool.v_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def _get_key_buffer(self, layer_id: int):
         # for internal use of referencing
         if self.store_dtype != self.dtype:
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
             layer_id_override=layer_id_pool,
         )
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError("HiCache not supported for SWAKVPool.")
-
 
 class AscendTokenToKVPool(MHATokenToKVPool):
 
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
-        self.
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.kv_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
         self.layer_transfer_counter = None
 
         kv_size = self.get_kv_size_bytes()
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        transfer_kv_per_layer_mla(
-            src=host_pool.kv_buffer[layer_id],
-            dst=self.kv_buffer[layer_id],
-            src_indices=host_indices,
-            dst_indices=device_indices,
-            io_backend=io_backend,
-            page_size=self.page_size,
-            item_size=self.token_stride,
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
-        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
-            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
-                raise ValueError(
-                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
-                )
-            transfer_kv_per_layer_mla(
-                src=self.kv_buffer[layer_id],
-                dst=host_pool.kv_buffer[layer_id],
-                src_indices=device_indices,
-                dst_indices=host_indices,
-                io_backend=io_backend,
-                page_size=self.page_size,
-                item_size=self.token_stride,
-            )
-
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
-    def load_from_host_per_layer(
-        self, host_pool, host_indices, device_indices, layer_id, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
-    def backup_to_host_all_layer(
-        self, host_pool, host_indices, device_indices, io_backend
-    ):
-        raise NotImplementedError(
-            "HiCache not supported for DoubleSparseTokenToKVPool."
-        )
-
 
 @triton.jit
 def copy_all_layer_kv_cache(
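With the per-layer `load_from_host_per_layer` / `backup_to_host_all_layer` methods removed from the device pools, the pools now just expose base-pointer tables (`k_data_ptrs`, `v_data_ptrs`, `data_ptrs`) so that the host pool and all-layer kernels can address every layer buffer from a single launch. A standalone sketch of that pointer-table construction; the shapes and layer count are invented for illustration, and `int64` is used here instead of the pool's `uint64` to keep the sketch portable across torch versions:

```python
import torch

layer_num, tokens, heads, head_dim = 4, 8, 2, 16  # invented sizes
k_buffer = [torch.zeros(tokens, heads, head_dim, dtype=torch.float16) for _ in range(layer_num)]
v_buffer = [torch.zeros(tokens, heads, head_dim, dtype=torch.float16) for _ in range(layer_num)]

# One base address per layer buffer; a kernel indexes this table instead of
# receiving a separate pointer argument per layer.
k_data_ptrs = torch.tensor([x.data_ptr() for x in k_buffer], dtype=torch.int64)
v_data_ptrs = torch.tensor([x.data_ptr() for x in v_buffer], dtype=torch.int64)
data_ptrs = torch.cat([k_data_ptrs, v_data_ptrs], dim=0)

# Bytes per token slot in each buffer, used to turn token indices into byte offsets.
data_strides = torch.tensor(
    [x.shape[1:].numel() * x.dtype.itemsize for x in k_buffer + v_buffer],
    dtype=torch.int64,
)
print(data_ptrs.shape, int(data_strides[0]))  # torch.Size([8]) 64
```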
|