sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sglang/compile_deep_gemm.py +8 -1
  2. sglang/global_config.py +5 -1
  3. sglang/srt/conversation.py +0 -112
  4. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  5. sglang/srt/disaggregation/prefill.py +1 -0
  6. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  7. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  8. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  9. sglang/srt/distributed/parallel_state.py +11 -0
  10. sglang/srt/entrypoints/engine.py +4 -2
  11. sglang/srt/entrypoints/http_server.py +35 -15
  12. sglang/srt/eplb/expert_distribution.py +4 -2
  13. sglang/srt/hf_transformers_utils.py +25 -10
  14. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  15. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  16. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  17. sglang/srt/layers/attention/vision.py +27 -10
  18. sglang/srt/layers/communicator.py +14 -4
  19. sglang/srt/layers/linear.py +7 -1
  20. sglang/srt/layers/logits_processor.py +9 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +11 -35
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
  24. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  25. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  26. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  27. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  28. sglang/srt/layers/moe/utils.py +43 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  31. sglang/srt/layers/quantization/fp8.py +5 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  33. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  34. sglang/srt/lora/lora_registry.py +7 -0
  35. sglang/srt/managers/cache_controller.py +8 -4
  36. sglang/srt/managers/data_parallel_controller.py +52 -2
  37. sglang/srt/managers/io_struct.py +6 -1
  38. sglang/srt/managers/schedule_batch.py +3 -2
  39. sglang/srt/managers/schedule_policy.py +3 -1
  40. sglang/srt/managers/scheduler.py +144 -6
  41. sglang/srt/managers/template_manager.py +25 -22
  42. sglang/srt/managers/tokenizer_manager.py +114 -62
  43. sglang/srt/managers/utils.py +45 -1
  44. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  45. sglang/srt/mem_cache/hicache_storage.py +13 -21
  46. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  47. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  48. sglang/srt/model_executor/cuda_graph_runner.py +17 -3
  49. sglang/srt/model_executor/forward_batch_info.py +13 -3
  50. sglang/srt/model_executor/model_runner.py +5 -0
  51. sglang/srt/models/deepseek_v2.py +23 -17
  52. sglang/srt/models/glm4_moe.py +82 -19
  53. sglang/srt/models/grok.py +3 -3
  54. sglang/srt/models/llama4.py +13 -2
  55. sglang/srt/models/mixtral.py +3 -3
  56. sglang/srt/models/mllama4.py +428 -19
  57. sglang/srt/models/qwen2_moe.py +1 -4
  58. sglang/srt/models/qwen3_moe.py +7 -8
  59. sglang/srt/models/step3_vl.py +1 -1
  60. sglang/srt/multimodal/processors/base_processor.py +4 -3
  61. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  62. sglang/srt/operations_strategy.py +1 -1
  63. sglang/srt/server_args.py +80 -20
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  65. sglang/srt/two_batch_overlap.py +6 -4
  66. sglang/srt/utils.py +3 -24
  67. sglang/srt/weight_sync/utils.py +1 -1
  68. sglang/test/runners.py +2 -2
  69. sglang/test/test_utils.py +3 -3
  70. sglang/version.py +1 -1
  71. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  72. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
  73. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  74. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  75. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  76. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  77. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  78. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  79. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  80. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py
@@ -0,0 +1,182 @@
+ from __future__ import annotations
+
+ import os
+ from typing import TYPE_CHECKING, List, Optional, Tuple
+
+ import torch
+ from torch.utils.cpp_extension import load
+
+ _abs_path = os.path.dirname(os.path.abspath(__file__))
+ radix_tree_cpp = load(
+     name="radix_tree_cpp",
+     sources=[
+         f"{_abs_path}/tree_v2_binding.cpp",
+         f"{_abs_path}/tree_v2_debug.cpp",
+         f"{_abs_path}/tree_v2.cpp",
+     ],
+     extra_cflags=["-O3", "-std=c++20"],
+ )
+
+ if TYPE_CHECKING:
+
+     class TreeNodeCpp:
+         """
+         A placeholder for the TreeNode class. Cannot be constructed elsewhere.
+         """
+
+     class IOHandle:
+         """
+         A placeholder for the IOHandle class. Cannot be constructed elsewhere.
+         """
+
+     class RadixTreeCpp:
+         def __init__(
+             self,
+             disabled: bool,
+             host_size: Optional[int],
+             page_size: int,
+             write_through_threshold: int,
+         ):
+             """
+             Initializes the RadixTreeCpp instance.
+             Args:
+                 disabled (bool): If True, the radix tree is disabled.
+                 host_size (Optional[int]): Size of the radix tree on the CPU. None means no CPU tree.
+                 page_size (int): Size of the page for the radix tree.
+                 write_through_threshold (int): Threshold for writing through from GPU to CPU.
+             """
+             self.tree = radix_tree_cpp.RadixTree(  # type: ignore
+                 disabled, host_size, page_size, write_through_threshold
+             )
+
+         def match_prefix(
+             self, prefix: List[int]
+         ) -> Tuple[List[torch.Tensor], int, TreeNodeCpp, TreeNodeCpp]:
+             """
+             Matches a prefix in the radix tree.
+             Args:
+                 prefix (List[int]): The prefix to match.
+             Returns:
+                 Tuple[List[torch.Tensor], TreeNodeCpp, TreeNodeCpp]:
+                     0. A list of indices that is matched by the prefix on the GPU.
+                     1. Sum length of the indices matched on the CPU.
+                     2. The last node of the prefix matched on the GPU.
+                     3. The last node of the prefix matched on the CPU.
+             """
+             return self.tree.match_prefix(prefix)
+
+         def evict(self, num_tokens: int) -> List[torch.Tensor]:
+             """
+             Evicts a number of tokens from the radix tree.
+             Args:
+                 num_tokens (int): The number of tokens to evict.
+             Returns:
+                 List[torch.Tensor]: A list of indices that were evicted.
+             """
+             return self.tree.evict(num_tokens)
+
+         def lock_ref(self, handle: TreeNodeCpp, lock: bool) -> None:
+             """
+             Locks or unlocks a reference to a tree node.
+             After locking, the node will not be evicted from the radix tree.
+             Args:
+                 handle (TreeNodeCpp): The tree node to lock or unlock.
+                 lock (bool): If True, locks the node; if False, unlocks it.
+             """
+             return self.tree.lock_ref(handle, lock)
+
+         def writing_through(
+             self, key: List[int], indices: torch.Tensor
+         ) -> Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+             """
+             Inserts a key-value pair into the radix tree and perform write-through check.
+             Args:
+                 key (List[int]): The key to insert.
+                 indices (torch.Tensor): The value associated with the key.
+             Returns:
+                 Tuple[List[Tuple[IOHandle, torch.Tensor, torch.Tensor]], int]:
+                     0. A list of (IOHandle, device indices, host indices) tuples.
+                        These IOhandles require write-through to the CPU in python side.
+                     1. The number of indices that are matched on device.
+             """
+             return self.tree.writing_through(key, indices)
+
+         def loading_onboard(
+             self,
+             host_node: TreeNodeCpp,
+             new_device_indices: torch.Tensor,
+         ) -> Tuple[IOHandle, List[torch.Tensor]]:
+             """
+             Updates the device indices of tree nodes within a range on the tree.
+             Args:
+                 host_node (TreeNodeCpp): The tree node on the host, must be descendant of device_node.
+                 new_device_indices (torch.Tensor): The new device indices to set.
+                     The length of this tensor must be exactly host indices length.
+             Returns:
+                 Tuple[IOHandle, List[torch.Tensor]]:
+                     0. An IOHandle that requires loading to the CPU in python side.
+                     1. A list of host indices corresponding to the new device indices.
+             """
+             return self.tree.loading_onboard(host_node, new_device_indices)
+
+         def commit_writing_through(self, handle: IOHandle, success: bool) -> None:
+             """
+             Commits the write-through process for a tree node.
+             Args:
+                 handle (IOHandle): The IOHandle to commit.
+                 success (bool): If True, commits the write-through; if False, just indicates failure.
+             """
+             return self.tree.commit_writing_through(handle, success)
+
+         def commit_loading_onboard(self, handle: IOHandle, success: bool) -> None:
+             """
+             Commits the load onboard process for tree nodes within a range on the tree.
+             Args:
+                 handle (IOHandle): The IOHandle to commit.
+                 success (bool): If True, commits the load-onboard; if False, just indicates failure.
+             """
+             return self.tree.commit_loading_onboard(handle, success)
+
+         def evictable_size(self) -> int:
+             """
+             Returns the size of the evictable part of the radix tree.
+             This is the size of the part that can be evicted from the GPU (ref_count = 0).
+             Returns:
+                 int: The size of the evictable part.
+             """
+             return self.tree.evictable_size()
+
+         def protected_size(self) -> int:
+             """
+             Returns the size of the protected part of the radix tree.
+             This is the size of the part that cannot be evicted from the GPU (ref_count > 0).
+             Returns:
+                 int: The size of the protected part.
+             """
+             return self.tree.protected_size()
+
+         def total_size(self) -> int:
+             """
+             Returns the total size of the radix tree (including CPU nodes).
+             Returns:
+                 int: The total size of the radix tree.
+             """
+             return self.tree.total_size()
+
+         def reset(self) -> None:
+             """
+             Resets the radix tree, clearing all nodes and indices.
+             """
+             return self.tree.reset()
+
+         def debug_print(self) -> None:
+             """
+             Prints the internal state of the radix tree for debugging purposes.
+             """
+             return self.tree.debug_print()
+
+ else:
+     # Real implementation of the classes for runtime
+     RadixTreeCpp = radix_tree_cpp.RadixTree
+     TreeNodeCpp = object
+     IOHandle = object
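
For orientation, a minimal usage sketch of the RadixTreeCpp binding added above (not part of the wheel itself), assuming the C++ sources compile and a GPU-only tree with host_size=None; the token ids and slot indices below are invented:

import torch
from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import RadixTreeCpp

# GPU-only tree: no host cache, page size of one token.
tree = RadixTreeCpp(disabled=False, host_size=None, page_size=1, write_through_threshold=1)

key = [1, 2, 3, 4]                       # token ids (toy values)
kv = torch.arange(4, dtype=torch.int64)  # KV-pool slots backing those tokens

# Insert; with no host cache the returned IOHandle list stays empty.
io_handles, already_cached = tree.writing_through(key, kv)

# Look up a longer prefix: matched device indices plus the last matched nodes.
chunks, host_hit_len, gpu_node, cpu_node = tree.match_prefix([1, 2, 3, 4, 5])
matched = torch.cat(chunks) if chunks else torch.empty(0, dtype=torch.int64)

tree.lock_ref(gpu_node, True)   # protect the matched path while a request uses it
tree.lock_ref(gpu_node, False)  # release it again
freed = tree.evict(2)           # reclaim up to 2 tokens of unprotected indices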
sglang/srt/mem_cache/hicache_storage.py
@@ -33,8 +33,7 @@ class HiCacheStorage(ABC):
      It abstracts the underlying storage mechanism, allowing different implementations to be used.
      """

-     # todo, translate tensor object access for different TP ranks
-     # potentially pass model and TP configs into storage backend
+     # todo, potentially pass model and TP configs into storage backend
      # todo, the page size of storage backend does not have to be the same as the same as host memory pool

      @abstractmethod
@@ -117,35 +116,28 @@ class HiCacheFile(HiCacheStorage):
      def get(
          self,
          key: str,
-         target_location: Optional[Any] = None,
+         target_location: torch.Tensor,
          target_sizes: Optional[Any] = None,
      ) -> torch.Tensor | None:
          key = self._get_suffixed_key(key)
          tensor_path = os.path.join(self.file_path, f"{key}.bin")
          try:
-             if target_location is not None:
-                 # Load directly into target_location's memory buffer
-                 with open(tensor_path, "rb") as f:
-                     target_location.set_(
-                         torch.frombuffer(f.read(), dtype=target_location.dtype)
-                         .reshape(target_location.shape)
-                         .storage()
-                     )
-                 return target_location
-             else:
-                 loaded_tensor = torch.load(tensor_path)
-                 if isinstance(loaded_tensor, torch.Tensor):
-                     return loaded_tensor
-                 else:
-                     logger.error(f"Loaded data for key {key} is not a tensor.")
-                     return None
+             # Load directly into target_location's memory buffer
+             with open(tensor_path, "rb") as f:
+                 target_location.set_(
+                     torch.frombuffer(f.read(), dtype=target_location.dtype)
+                     .reshape(target_location.shape)
+                     .untyped_storage()
+                 )
+             return target_location
          except FileNotFoundError:
+             logger.warning(f"Failed to fetch {key} from HiCacheFile storage.")
              return None

      def batch_get(
          self,
          keys: List[str],
-         target_locations: Optional[Any] = None,
+         target_locations: List[torch.Tensor],
          target_sizes: Optional[Any] = None,
      ) -> List[torch.Tensor | None]:
          return [
@@ -168,7 +160,7 @@ class HiCacheFile(HiCacheStorage):
              logger.debug(f"Key {key} already exists. Skipped.")
              return True
          try:
-             torch.save(value, tensor_path)
+             value.contiguous().view(dtype=torch.uint8).numpy().tofile(tensor_path)
              return True
          except Exception as e:
              logger.error(f"Failed to save tensor {key}: {e}")
sglang/srt/mem_cache/radix_cache_cpp.py
@@ -0,0 +1,229 @@
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING, List, Set
+
+ import torch
+
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+ from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
+     IOHandle,
+     RadixTreeCpp,
+     TreeNodeCpp,
+ )
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+ if TYPE_CHECKING:
+     from sglang.srt.managers.schedule_batch import Req
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class RadixCacheCpp(BasePrefixCache):
+     def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
+         """
+         Merge a list of tensors into a single tensor.
+         Args:
+             l (List[torch.Tensor]): List of tensors to merge.
+         Returns:
+             torch.Tensor: Merged tensor.
+         """
+         if len(l) == 0:
+             return torch.empty(0, dtype=torch.int64, device=self.device)
+         elif len(l) == 1:
+             return l[0]
+         else:
+             return torch.cat(l)
+
+     def __init__(
+         self,
+         disable: bool,
+         use_hicache: bool,
+         req_to_token_pool: ReqToTokenPool,
+         token_to_kv_pool: BaseTokenToKVPoolAllocator,
+         tp_cache_group: torch.distributed.ProcessGroup,
+         page_size: int,
+         hicache_ratio: float,
+         hicache_size: int,
+         hicache_write_policy: str,
+         enable_kv_cache_events: bool = False,
+         hicache_oracle: bool = False,
+         enable_write_cancel: bool = False,
+     ):
+         self.disable = disable
+         self.enable_write_cancel = enable_write_cancel
+
+         assert (
+             enable_kv_cache_events is False
+         ), "HiRadixCache does not support kv cache events yet"
+         self.kv_cache = token_to_kv_pool.get_kvcache()
+
+         # record the nodes with ongoing write through
+         self.ongoing_write_through: Set[IOHandle] = set()
+         # record the node segments with ongoing load back
+         self.ongoing_load_back: Set[IOHandle] = set()
+         # todo: dynamically adjust the threshold
+         self.write_through_threshold = (
+             1 if hicache_write_policy == "write_through" else 2
+         )
+         self.device = token_to_kv_pool.device
+         self.token_to_kv_pool = token_to_kv_pool
+         self.req_to_token_pool = req_to_token_pool
+         self.page_size = page_size
+
+         self.tp_group = tp_cache_group
+
+         if not use_hicache:
+             self.tree = RadixTreeCpp(
+                 disabled=self.disable,
+                 page_size=page_size,
+                 host_size=None,  # no host cache, this should be removed in the future
+                 write_through_threshold=self.write_through_threshold,
+             )
+             self.cache_controller = None
+             return  # early return if hicache is not used
+
+         raise NotImplementedError("Host cache is not supported yet")
+
+     def reset(self):
+         if self.cache_controller is not None:
+             # need to clear the acks before resetting the cache controller
+             raise NotImplementedError("Host cache is not supported yet")
+         self.tree.reset()
+
+     def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+         device_indices_vec, host_indices_length, node_gpu, node_cpu = (
+             self.tree.match_prefix(key)
+         )
+         return MatchResult(
+             device_indices=self._merge_tensor(device_indices_vec),
+             last_device_node=node_gpu,
+             last_host_node=node_cpu,
+             host_hit_length=host_indices_length,
+         )
+
+     def _insert(self, key: List[int], value: torch.Tensor) -> int:
+         """
+         Insert a key-value pair into the radix tree.
+         Args:
+             key (List[int]): The key to insert, represented as a list of integers.
+             value (torch.Tensor): The value to associate with the key.
+         Returns:
+             int: Number of device indices that were already present in the tree before the insertion.
+         """
+         ongoing_write, length = self.tree.writing_through(key, value)
+         if self.cache_controller is None:
+             assert len(ongoing_write) == 0, "Implementation error"
+             return length
+
+         raise NotImplementedError("Host cache is not supported yet")
+
+     def dec_lock_ref(self, node: TreeNodeCpp):
+         """
+         Decrement the reference count of a node to root of the radix tree.
+         Args:
+             node (TreeNodeCpp): The handle of the node to decrement the reference count for.
+         """
+         self.tree.lock_ref(node, False)  # do not increment
+
+     def inc_lock_ref(self, node: TreeNodeCpp):
+         """
+         Increment the reference count of from a node to root of the radix tree.
+         Args:
+             node (TreeNodeCpp): The handle of the node to increment the reference count for.
+         """
+         self.tree.lock_ref(node, True)
+
+     def evict(self, num_tokens: int):
+         evicted_device_indices = self.tree.evict(num_tokens)
+         for indice in evicted_device_indices:
+             self.token_to_kv_pool.free(indice)
+
+     def evictable_size(self):
+         return self.tree.evictable_size()
+
+     def protected_size(self):
+         return self.tree.protected_size()
+
+     def total_size(self):
+         return self.tree.total_size()
+
+     def cache_finished_req(self, req: Req):
+         """Cache request when it finishes."""
+         assert req.req_pool_idx is not None
+         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+         overall_len = len(token_ids)  # prefill + decode
+         kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
+
+         # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+         # it will automatically align them, but length of them should be equal
+         old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+         new_prefix_len = self._insert(token_ids, kv_indices)
+
+         # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+         assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+         # KVCache between old & new is newly generated, but already exists in the pool
+         # we need to free this newly generated kv indices
+         if old_prefix_len < new_prefix_len:
+             self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+
+         # need to free the unaligned part, since it cannot be inserted into the radix tree
+         if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
+             (unaligned_len := overall_len % self.page_size) > 0
+         ):
+             # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
+             self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+
+         # Remove req slot release the cache lock
+         self.dec_lock_ref(req.last_node)
+         self.req_to_token_pool.free(req.req_pool_idx)
+
+     def cache_unfinished_req(self, req: Req):
+         """Cache request when it is unfinished."""
+         assert req.req_pool_idx is not None
+         token_ids = req.fill_ids
+         prefill_len = len(token_ids)  # prefill only (maybe chunked)
+         kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]
+
+         # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+         # it will automatically align them, but length of them should be equal
+         old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+         new_prefix_len = self._insert(token_ids, kv_indices)
+
+         # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+         assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+         # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
+         # The prefix indices need to updated to reuse the kv indices in the pool
+         new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
+         new_indices = self._merge_tensor(new_indices_vec)
+         assert new_prefix_len <= len(new_indices)
+
+         # KVCache between old & new is newly generated, but already exists in the pool
+         # we need to free this newly generated kv indices and reuse the indices in the pool
+         if old_prefix_len < new_prefix_len:
+             self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+             reused_indices = new_indices[old_prefix_len:new_prefix_len]
+             self.req_to_token_pool.req_to_token[
+                 req.req_pool_idx, old_prefix_len:new_prefix_len
+             ] = reused_indices
+
+         if req.last_node != new_last_node:
+             self.dec_lock_ref(req.last_node)
+             self.inc_lock_ref(new_last_node)
+
+         # NOTE: there might be unaligned tail, so we may need to append it
+         assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
+         if self.page_size != 1 and len(new_indices) < prefill_len:
+             req.prefix_indices = torch.cat(
+                 [new_indices, kv_indices[len(new_indices) :]]
+             )
+         else:
+             req.prefix_indices = new_indices
+         req.last_node = new_last_node
+
+     def pretty_print(self):
+         return self.tree.debug_print()
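
The page-alignment bookkeeping in cache_finished_req above reduces to simple integer arithmetic; a worked toy example with invented numbers:

page_size = 16
prefix_indices_len = 37  # tokens already cached for this request
overall_len = 70         # prefill + decode tokens when the request finishes

# Only whole pages of the old prefix count as already present in the tree.
old_prefix_len = prefix_indices_len // page_size * page_size  # 32

# Suppose the insert reports that 64 tokens are now covered by the tree.
new_prefix_len = 64
assert old_prefix_len <= new_prefix_len

# Slots [32, 64) were recomputed this run but already live in the tree, so they are freed.
# The unaligned tail cannot enter the radix tree and is freed as well.
unaligned_len = overall_len % page_size   # 6
tail_start = overall_len - unaligned_len  # 64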
sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp
@@ -0,0 +1,35 @@
+ #include <torch/extension.h>
+
+ #include <cstring>
+ #include <vector>
+
+ void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
+   py::gil_scoped_release release;
+   char *src_ptr = static_cast<char *>(shm.data_ptr());
+   size_t current = 0;
+   for (size_t i = 0; i < dst.size(); ++i) {
+     auto &t = dst[i];
+     size_t t_bytes = t.numel() * t.element_size();
+     char *dst_ptr = static_cast<char *>(t.data_ptr());
+     std::memcpy(dst_ptr, src_ptr + current, t_bytes);
+     current += t_bytes;
+   }
+ }
+
+ void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
+   py::gil_scoped_release release;
+   char *dst_ptr = static_cast<char *>(shm.data_ptr());
+   size_t current = 0;
+   for (size_t i = 0; i < src.size(); ++i) {
+     auto &t = src[i];
+     size_t t_bytes = t.numel() * t.element_size();
+     char *src_ptr = static_cast<char *>(t.data_ptr());
+     std::memcpy(dst_ptr + current, src_ptr, t_bytes);
+     current += t_bytes;
+   }
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+   m.def("read_shm", &read_shm, "Read tensors from shared memory");
+   m.def("write_shm", &write_shm, "Write tensors to shared memory");
+ }
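
A rough sketch of how a Python caller might build and exercise this extension with torch.utils.cpp_extension.load; the module name, source path, and buffer sizing here are assumptions rather than package code:

import torch
from torch.utils.cpp_extension import load

# Hypothetical build of the extension shown above.
hf3fs_utils = load(name="hf3fs_utils", sources=["hf3fs_utils.cpp"])

pages = [torch.randn(2, 4), torch.randn(3)]
total_bytes = sum(t.numel() * t.element_size() for t in pages)
shm = torch.empty(total_bytes, dtype=torch.uint8)  # stands in for a shared-memory buffer

# Pack the tensors back-to-back into the buffer, then unpack into fresh tensors.
hf3fs_utils.write_shm(pages, shm)
out = [torch.empty_like(t) for t in pages]
hf3fs_utils.read_shm(shm, out)
assert all(torch.equal(a, b) for a, b in zip(pages, out))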
sglang/srt/model_executor/cuda_graph_runner.py
@@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile

  from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+     set_graph_pool_id,
+ )
  from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
  from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -372,6 +375,11 @@ class CudaGraphRunner:
              dtype=torch.bool,
              device="cuda",
          )
+         self.next_token_logits_buffer = torch.zeros(
+             (self.max_num_token, self.model_runner.model_config.vocab_size),
+             dtype=torch.float,
+             device="cuda",
+         )

          # Capture
          try:
@@ -517,6 +525,7 @@ class CudaGraphRunner:
          else:
              encoder_lens = None
          mrope_positions = self.mrope_positions[:, :bs]
+         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
          self.num_token_non_padded[...] = num_tokens

          # pipeline parallelism
@@ -579,6 +588,7 @@ class CudaGraphRunner:
              input_ids=input_ids,
              req_pool_indices=req_pool_indices,
              seq_lens=seq_lens,
+             next_token_logits_buffer=next_token_logits_buffer,
              req_to_token_pool=self.model_runner.req_to_token_pool,
              token_to_kv_pool=self.model_runner.token_to_kv_pool,
              attn_backend=self.model_runner.attn_backend,
@@ -643,11 +653,15 @@ class CudaGraphRunner:

          run_once()

-         global global_graph_memory_pool
-         with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
+         if get_global_graph_memory_pool() is None:
+             set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+         # Set graph pool id globally to be able to use symmetric memory
+         set_graph_pool_id(get_global_graph_memory_pool())
+         with torch.cuda.graph(
+             graph, pool=get_global_graph_memory_pool(), stream=stream
+         ):
              out = run_once()

-         global_graph_memory_pool = graph.pool()
          return graph, out

      def recapture_if_needed(self, forward_batch: ForwardBatch):
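
The capture change above routes every CUDA graph through one shared memory pool so the pynccl allocator can be told about it via set_graph_pool_id (for symmetric memory). Stripped of sglang specifics, the underlying PyTorch pattern is roughly:

import torch

_pool = None  # plays the role of the global graph memory pool

def capture(run_once, stream):
    global _pool
    if _pool is None:
        _pool = torch.cuda.graph_pool_handle()  # one pool reused by all captures
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=_pool, stream=stream):
        out = run_once()
    return g, out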
sglang/srt/model_executor/forward_batch_info.py
@@ -38,6 +38,7 @@ import torch
  import triton
  import triton.language as tl

+ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
  from sglang.srt.layers.dp_attention import (
      DPPaddingMode,
      get_attention_dp_rank,
@@ -188,6 +189,7 @@ class ForwardBatch:
      token_ids_logprobs: Optional[List[List[int]]] = None

      # For logits and logprobs post processing
+     next_token_logits_buffer: torch.Tensor = None
      temp_scaled_logprobs: bool = False
      temperature: torch.Tensor = None
      top_p_normalized_logprobs: bool = False
@@ -644,12 +646,17 @@ class ForwardBatch:
              device=model_runner.device,
          )

-         bs = self.batch_size
          if len(global_num_tokens) > 1:
              num_tokens = global_num_tokens[get_attention_dp_rank()]
          else:
              num_tokens = global_num_tokens[0]

+         if self.forward_mode.is_decode():
+             setattr(self, "raw_bs", self.batch_size)
+             self.batch_size = num_tokens
+
+         bs = self.batch_size
+
          # padding
          self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
          self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -657,6 +664,9 @@ class ForwardBatch:
          seq_len_fill_value = (
              model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
          )
+         self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * (
+             bs - self.seq_lens.shape[0]
+         )
          self.seq_lens = self._pad_tensor_to_size(
              self.seq_lens, bs, value=seq_len_fill_value
          )
@@ -700,7 +710,7 @@ class ForwardBatch:

      def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

-         bs = self.batch_size
+         bs = getattr(self, "raw_bs", self.batch_size)

          if self.spec_info is not None:
              if self.forward_mode.is_decode():  # draft
@@ -839,7 +849,7 @@


  def enable_num_token_non_padded(server_args):
-     return server_args.enable_ep_moe or server_args.enable_deepep_moe
+     return get_moe_expert_parallel_world_size() > 1


  class PPProxyTensors:
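
The prepare/post pair above temporarily pads a decode batch to the DP-synchronized token count and stashes the real size in raw_bs; the bookkeeping amounts to this, with toy numbers:

raw_bs = 3                # real decode batch size on this rank
bs = 5                    # padded batch size used for the forward pass
seq_lens = [17, 9, 20]    # real sequence lengths (len == raw_bs)
seq_len_fill_value = 1    # filler value used for the padded rows

# Each padded row contributes seq_len_fill_value, so the running sum is patched:
seq_lens_sum = sum(seq_lens) + seq_len_fill_value * (bs - len(seq_lens))  # 46 + 2 = 48

# After the forward pass, post-processing truncates back to the real batch:
bs_for_postprocess = raw_bs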
sglang/srt/model_executor/model_runner.py
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
      initialize_dp_attention,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
  from sglang.srt.layers.quantization import (
      deep_gemm_wrapper,
      monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
                  "use_mla_backend": self.use_mla_backend,
                  "speculative_algorithm": self.spec_algorithm,
              }
+             | {
+                 "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
+                 "deepep_mode": DeepEPMode(server_args.deepep_mode),
+             }
          )

          # CPU offload