sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +11 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +5 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +8 -4
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +144 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +17 -3
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/deepseek_v2.py +23 -17
- sglang/srt/models/glm4_moe.py +82 -19
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +80 -20
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +3 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py (new file)
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+import torch
+from torch.utils.cpp_extension import load
+
+_abs_path = os.path.dirname(os.path.abspath(__file__))
+radix_tree_cpp = load(
+    name="radix_tree_cpp",
+    sources=[
+        f"{_abs_path}/tree_v2_binding.cpp",
+        f"{_abs_path}/tree_v2_debug.cpp",
+        f"{_abs_path}/tree_v2.cpp",
+    ],
+    extra_cflags=["-O3", "-std=c++20"],
+)
+
+if TYPE_CHECKING:
+
+    class TreeNodeCpp:
+        """
+        A placeholder for the TreeNode class. Cannot be constructed elsewhere.
+        """
+
+    class IOHandle:
+        """
+        A placeholder for the IOHandle class. Cannot be constructed elsewhere.
+        """
+
+    class RadixTreeCpp:
+        def __init__(
+            self,
+            disabled: bool,
+            host_size: Optional[int],
+            page_size: int,
+            write_through_threshold: int,
+        ):
+            """
+            Initializes the RadixTreeCpp instance.
+            Args:
+                disabled (bool): If True, the radix tree is disabled.
+                host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
+                page_size (int): Size of the page for the radix tree.
+                write_through_threshold (int): Threshold for writing through from GPU to CPU.
+            """
+            self.tree = radix_tree_cpp.RadixTree(  # type: ignore
+                disabled, host_size, page_size, write_through_threshold
+            )
+
+        def match_prefix(
+            self, prefix: List[int]
+        ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+            """
+            Matches a prefix in the radix tree.
+            Args:
+                prefix (List[int]): The prefix to match.
+            Returns:
+                Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
+                    0. A list of indices that is matched by the prefix on the GPU.
+                    1. Sum length of the indices matched on the CPU.
+                    2. The last node of the prefix matched on the GPU.
+                    3. The last node of the prefix matched on the CPU.
+            """
+            return self.tree.match_prefix(prefix)
+
+        def evict(self, num_tokens: int) -> List[torch.Tensor]:
+            """
+            Evicts a number of tokens from the radix tree.
+            Args:
+                num_tokens (int): The number of tokens to evict.
+            Returns:
+                List[torch.Tensor]: A list of indices that were evicted.
+            """
+            return self.tree.evict(num_tokens)
+
+        def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
+            """
+            Locks or unlocks a reference to a tree node.
+            After locking, the node will not be evicted from the radix tree.
+            Args:
+                handle (TreeNodeCpp): The tree node to lock or unlock.
+                lock (bool): If True, locks the node; if False, unlocks it.
+            """
+            return self.tree.lock_ref(handle, lock)
+
+        def writing_through(
+            self, key: List[int], indices: torch.Tensor
+        ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+            """
+            Inserts a key-value pair into the radix tree and perform write-through check.
+            Args:
+                key (List[int]): The key to insert.
+                indices (torch.Tensor): The value associated with the key.
+            Returns:
+                Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+                    0. A list of (IOHandle, device indices, host indices) tuples.
+                       These IOhandles require write-through to the CPU in python side.
+                    1. The number of indices that are matched on device.
+            """
+            return self.tree.writing_through(key, indices)
+
+        def loading_onboard(
+            self,
+            host_node: TreeNodeCpp,
+            new_device_indices: torch.Tensor,
+        ) -> Tuple[IOHandle, List[torch.Tensor]]:
+            """
+            Updates the device indices of tree nodes within a range on the tree.
+            Args:
+                host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
+                new_device_indices (torch.Tensor): The new device indices to set.
+                    The length of this tensor must be exactly host indices length.
+            Returns:
+                Tuple[IOHandle, List[torch.Tensor]]:
+                    0. An IOHandle that requires loading to the CPU in python side.
+                    1. A list of host indices corresponding to the new device indices.
+            """
+            return self.tree.loading_onboard(host_node, new_device_indices)
+
+        def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the write-through process for a tree node.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the write-through; if False, just indicates failure.
+            """
+            return self.tree.commit_writing_through(handle, success)
+
+        def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
+            """
+            Commits the load onboard process for tree nodes within a range on the tree.
+            Args:
+                handle (IOHandle): The IOHandle to commit.
+                success (bool): If True, commits the load-onboard; if False, just indicates failure.
+            """
+            return self.tree.commit_loading_onboard(handle, success)
+
+        def evictable_size(self) -> int:
+            """
+            Returns the size of the evictable part of the radix tree.
+            This is the size of the part that can be evicted from the GPU (ref_count = 0).
+            Returns:
+                int: The size of the evictable part.
+            """
+            return self.tree.evictable_size()
+
+        def protected_size(self) -> int:
+            """
+            Returns the size of the protected part of the radix tree.
+            This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
+            Returns:
+                int: The size of the protected part.
+            """
+            return self.tree.protected_size()
+
+        def total_size(self) -> int:
+            """
+            Returns the total size of the radix tree (including CPU nodes).
+            Returns:
+                int: The total size of the radix tree.
+            """
+            return self.tree.total_size()
+
+        def reset(self) -> None:
+            """
+            Resets the radix tree, clearing all nodes and indices.
+            """
+            return self.tree.reset()
+
+        def debug_print(self) -> None:
+            """
+            Prints the internal state of the radix tree for debugging purposes.
+            """
+            return self.tree.debug_print()
+
+else:
+    # Real implementation of the classes for runtime
+    RadixTreeCpp = radix_tree_cpp.RadixTree
+    TreeNodeCpp = object
+    IOHandle = object
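For orientation, here is a minimal sketch of the call sequence the binding above documents, assuming the tree_v2 C++ sources build at import time; the key and index values are made up for illustration and are not part of the diff.

import torch
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import RadixTreeCpp

# GPU-only tree: host_size=None disables the CPU tier, page_size=1 keeps alignment trivial.
tree = RadixTreeCpp(disabled=False, host_size=None, page_size=1, write_through_threshold=1)

key = [1, 2, 3, 4]                             # token ids (illustrative)
indices = torch.arange(4, dtype=torch.int64)   # KV-pool slots for those tokens (illustrative)

# Insert; with a CPU tier configured, each returned handle would need a
# device-to-host copy followed by commit_writing_through.
ongoing, num_matched = tree.writing_through(key, indices)
for handle, device_idx, host_idx in ongoing:
    tree.commit_writing_through(handle, True)

# Look up the prefix later and pin the matched node while its KV slots are in use.
device_vecs, host_len, gpu_node, cpu_node = tree.match_prefix(key)
tree.lock_ref(gpu_node, True)   # protect from eviction
# ... use the matched indices ...
tree.lock_ref(gpu_node, False)  # release so evict() may reclaim them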
sglang/srt/mem_cache/hicache_storage.py
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """
 
-    # todo,
-    # potentially pass model and TP configs into storage backend
+    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool
 
     @abstractmethod
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
     def get(
         self,
         key: str,
-        target_location:
+        target_location: torch.Tensor,
         target_sizes: Optional[Any] = None,
     ) -> torch.Tensor | None:
         key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
-
-
-
-            target_location.
-
-
-
-
-                return target_location
-            else:
-                loaded_tensor = torch.load(tensor_path)
-                if isinstance(loaded_tensor, torch.Tensor):
-                    return loaded_tensor
-                else:
-                    logger.error(f"Loaded data for key {key} is not a tensor.")
-                    return None
+            # Load directly into target_location's memory buffer
+            with open(tensor_path, "rb") as f:
+                target_location.set_(
+                    torch.frombuffer(f.read(), dtype=target_location.dtype)
+                    .reshape(target_location.shape)
+                    .untyped_storage()
+                )
+            return target_location
         except FileNotFoundError:
+            logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
             return None
 
     def batch_get(
         self,
         keys: List[str],
-        target_locations:
+        target_locations: List[torch.Tensor],
         target_sizes: Optional[Any] = None,
    ) -> List[torch.Tensor | None]:
         return [
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
             logger.debug(f"Key {key} already exists. Skipped.")
             return True
         try:
-            torch.
+            value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
             return True
         except Exception as e:
             logger.error(f"Failed to save tensor {key}: {e}")
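The new get()/set() pair replaces pickled torch.save/torch.load files with a raw-bytes format. A small round-trip sketch of that format follows; the file path and tensor values are illustrative and not part of the diff.

import torch

value = torch.randn(16, dtype=torch.float16)

# set(): dump the raw bytes of the tensor.
value.contiguous().view(dtype=torch.uint8).numpy().tofile("/tmp/demo.bin")

# get(): rebuild the tensor directly into a preallocated buffer.
target_location = torch.empty_like(value)
with open("/tmp/demo.bin", "rb") as f:
    target_location.set_(
        torch.frombuffer(f.read(), dtype=target_location.dtype)
        .reshape(target_location.shape)
        .untyped_storage()
    )
assert torch.equal(target_location, value)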
sglang/srt/mem_cache/radix_cache_cpp.py (new file)
@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Set
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
+    IOHandle,
+    RadixTreeCpp,
+    TreeNodeCpp,
+)
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+
+logger = logging.getLogger(__name__)
+
+
+class RadixCacheCpp(BasePrefixCache):
+    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Merge a list of tensors into a single tensor.
+        Args:
+            l (List[torch.Tensor]): List of tensors to merge.
+        Returns:
+            torch.Tensor: Merged tensor.
+        """
+        if len(l) == 0:
+            return torch.empty(0, dtype=torch.int64, device=self.device)
+        elif len(l) == 1:
+            return l[0]
+        else:
+            return torch.cat(l)
+
+    def __init__(
+        self,
+        disable: bool,
+        use_hicache: bool,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
+        hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
+        enable_kv_cache_events: bool = False,
+        hicache_oracle: bool = False,
+        enable_write_cancel: bool = False,
+    ):
+        self.disable = disable
+        self.enable_write_cancel = enable_write_cancel
+
+        assert (
+            enable_kv_cache_events is False
+        ), "HiRadixCache does not support kv cache events yet"
+        self.kv_cache = token_to_kv_pool.get_kvcache()
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through: Set[IOHandle] = set()
+        # record the node segments with ongoing load back
+        self.ongoing_load_back: Set[IOHandle] = set()
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 2
+        )
+        self.device = token_to_kv_pool.device
+        self.token_to_kv_pool = token_to_kv_pool
+        self.req_to_token_pool = req_to_token_pool
+        self.page_size = page_size
+
+        self.tp_group = tp_cache_group
+
+        if not use_hicache:
+            self.tree = RadixTreeCpp(
+                disabled=self.disable,
+                page_size=page_size,
+                host_size=None,  # no host cache, this should be removed in the future
+                write_through_threshold=self.write_through_threshold,
+            )
+            self.cache_controller = None
+            return  # early return if hicache is not used
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def reset(self):
+        if self.cache_controller is not None:
+            # need to clear the acks before resetting the cache controller
+            raise NotImplementedError("Host cache is not supported yet")
+        self.tree.reset()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
+            self.tree.match_prefix(key)
+        )
+        return MatchResult(
+            device_indices=self._merge_tensor(device_indices_vec),
+            last_device_node=node_gpu,
+            last_host_node=node_cpu,
+            host_hit_length=host_indices_length,
+        )
+
+    def _insert(self, key: List[int], value: torch.Tensor) -> int:
+        """
+        Insert a key-value pair into the radix tree.
+        Args:
+            key (List[int]): The key to insert, represented as a list of integers.
+            value (torch.Tensor): The value to associate with the key.
+        Returns:
+            int: Number of device indices that were already present in the tree before the insertion.
+        """
+        ongoing_write, length = self.tree.writing_through(key, value)
+        if self.cache_controller is None:
+            assert len(ongoing_write) == 0, "Implementation error"
+            return length
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def dec_lock_ref(self, node: TreeNodeCpp):
+        """
+        Decrement the reference count of a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
+        """
+        self.tree.lock_ref(node, False)  # do not increment
+
+    def inc_lock_ref(self, node: TreeNodeCpp):
+        """
+        Increment the reference count of from a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to increment the reference count for.
+        """
+        self.tree.lock_ref(node, True)
+
+    def evict(self, num_tokens: int):
+        evicted_device_indices = self.tree.evict(num_tokens)
+        for indice in evicted_device_indices:
+            self.token_to_kv_pool.free(indice)
+
+    def evictable_size(self):
+        return self.tree.evictable_size()
+
+    def protected_size(self):
+        return self.tree.protected_size()
+
+    def total_size(self):
+        return self.tree.total_size()
+
+    def cache_finished_req(self, req: Req):
+        """Cache request when it finishes."""
+        assert req.req_pool_idx is not None
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        overall_len = len(token_ids)  # prefill + decode
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+
+        # need to free the unaligned part, since it cannot be inserted into the radix tree
+        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
+            (unaligned_len := overall_len % self.page_size) > 0
+        ):
+            # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
+            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+
+        # Remove req slot release the cache lock
+        self.dec_lock_ref(req.last_node)
+        self.req_to_token_pool.free(req.req_pool_idx)
+
+    def cache_unfinished_req(self, req: Req):
+        """Cache request when it is unfinished."""
+        assert req.req_pool_idx is not None
+        token_ids = req.fill_ids
+        prefill_len = len(token_ids)  # prefill only (maybe chunked)
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
+        # The prefix indices need to updated to reuse the kv indices in the pool
+        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
+        new_indices = self._merge_tensor(new_indices_vec)
+        assert new_prefix_len <= len(new_indices)
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices and reuse the indices in the pool
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+            reused_indices = new_indices[old_prefix_len:new_prefix_len]
+            self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, old_prefix_len:new_prefix_len
+            ] = reused_indices
+
+        if req.last_node != new_last_node:
+            self.dec_lock_ref(req.last_node)
+            self.inc_lock_ref(new_last_node)
+
+        # NOTE: there might be unaligned tail, so we may need to append it
+        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
+        if self.page_size != 1 and len(new_indices) < prefill_len:
+            req.prefix_indices = torch.cat(
+                [new_indices, kv_indices[len(new_indices) :]]
+            )
+        else:
+            req.prefix_indices = new_indices
+        req.last_node = new_last_node
+
+    def pretty_print(self):
+        return self.tree.debug_print()
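How a scheduler-side caller would consume this cache is sketched below; the radix_cache instance and token ids are illustrative, since constructing a real instance requires the scheduler's pools and TP group.

# Assuming `radix_cache` is an already-constructed RadixCacheCpp (GPU-only, use_hicache=False).
token_ids = [101, 2023, 2003, 1037]

result = radix_cache.match_prefix(key=token_ids)
prefix_len = len(result.device_indices)   # tokens whose KV is already on the GPU
host_hit = result.host_hit_length         # 0 here, since no host tier is configured

radix_cache.inc_lock_ref(result.last_device_node)   # pin the match while the request runs
# ... prefill only token_ids[prefix_len:], then cache_unfinished_req / cache_finished_req ...
radix_cache.dec_lock_ref(result.last_device_node)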
sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp (new file)
@@ -0,0 +1,35 @@
+#include <torch/extension.h>
+
+#include <cstring>
+#include <vector>
+
+void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
+  py::gil_scoped_release release;
+  char *src_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < dst.size(); ++i) {
+    auto &t = dst[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *dst_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr, src_ptr + current, t_bytes);
+    current += t_bytes;
+  }
+}
+
+void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
+  py::gil_scoped_release release;
+  char *dst_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < src.size(); ++i) {
+    auto &t = src[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *src_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr + current, src_ptr, t_bytes);
+    current += t_bytes;
+  }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("read_shm", &read_shm, "Read tensors from shared memory");
+  m.def("write_shm", &write_shm, "Write tensors to shared memory");
+}
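This extension is a thin pack/unpack helper: write_shm copies a list of tensors back-to-back into one flat buffer, and read_shm copies them out again in the same order. A hedged sketch of how it could be compiled and exercised from Python, following the same torch.utils.cpp_extension.load pattern that radix_tree.py uses; the source path and tensor values are illustrative.

import torch
from torch.utils.cpp_extension import load

hf3fs_utils = load(
    name="hf3fs_utils",
    sources=["sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp"],
)

a = torch.arange(4, dtype=torch.float32)
b = torch.arange(8, dtype=torch.int64)
shm = torch.empty(
    a.numel() * a.element_size() + b.numel() * b.element_size(), dtype=torch.uint8
)

hf3fs_utils.write_shm([a, b], shm)         # pack both tensors into the flat buffer
out_a, out_b = torch.empty_like(a), torch.empty_like(b)
hf3fs_utils.read_shm(shm, [out_a, out_b])  # unpack in the same order and shapes
assert torch.equal(out_a, a) and torch.equal(out_b, b)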
sglang/srt/model_executor/cuda_graph_runner.py
@@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -372,6 +375,11 @@ class CudaGraphRunner:
             dtype=torch.bool,
             device="cuda",
         )
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_num_token, self.model_runner.model_config.vocab_size),
+            dtype=torch.float,
+            device="cuda",
+        )
 
         # Capture
         try:
@@ -517,6 +525,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism
@@ -579,6 +588,7 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -643,11 +653,15 @@ class CudaGraphRunner:
 
         run_once()
 
-
-
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
             out = run_once()
 
-        global_graph_memory_pool = graph.pool()
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
sglang/srt/model_executor/forward_batch_info.py
@@ -38,6 +38,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
     DPPaddingMode,
     get_attention_dp_rank,
@@ -188,6 +189,7 @@ class ForwardBatch:
     token_ids_logprobs: Optional[List[List[int]]] = None
 
     # For logits and logprobs post processing
+    next_token_logits_buffer: torch.Tensor = None
     temp_scaled_logprobs: bool = False
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
@@ -644,12 +646,17 @@ class ForwardBatch:
             device=model_runner.device,
         )
 
-        bs = self.batch_size
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]
 
+        if self.forward_mode.is_decode():
+            setattr(self, "raw_bs", self.batch_size)
+            self.batch_size = num_tokens
+
+        bs = self.batch_size
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -657,6 +664,9 @@ class ForwardBatch:
         seq_len_fill_value = (
             model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
+        self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+            bs - self.seq_lens.shape[0]
+        )
         self.seq_lens = self._pad_tensor_to_size(
             self.seq_lens, bs, value=seq_len_fill_value
         )
@@ -700,7 +710,7 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = self.batch_size
+        bs = getattr(self, "raw_bs", self.batch_size)
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
@@ -839,7 +849,7 @@
 
 
 def enable_num_token_non_padded(server_args):
-    return
+    return get_moe_expert_parallel_world_size() > 1
 
 
 class PPProxyTensors:
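The decode path now pads the local batch up to the DP-wide token count while stashing the unpadded size in raw_bs, and seq_lens_sum is corrected for the padded rows before seq_lens itself is padded. A tiny numeric illustration of that bookkeeping, with made-up values:

import torch

seq_lens = torch.tensor([7, 9, 4])           # 3 local decode requests
seq_lens_sum = int(seq_lens.sum())           # 20
bs = 5                                       # padded (DP-wide) decode batch size
seq_len_fill_value = 1                       # backend-specific fill value

seq_lens_sum += seq_len_fill_value * (bs - seq_lens.shape[0])   # 22
pad = torch.full((bs - seq_lens.shape[0],), seq_len_fill_value)
seq_lens = torch.cat([seq_lens, pad])        # tensor([7, 9, 4, 1, 1])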
sglang/srt/model_executor/model_runner.py
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.layers.quantization import (
     deep_gemm_wrapper,
     monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
                 "use_mla_backend": self.use_mla_backend,
                 "speculative_algorithm": self.spec_algorithm,
             }
+            | {
+                "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+                "deepep_mode": DeepEPMode(server_args.deepep_mode),
+            }
         )
 
         # CPU offload