sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py

@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch

@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)


+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple contains a key and the key of its prefix block.
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains a boolean indicating whether the key has existed and an integer indicating the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, where each tuple contains a key and its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+        Returns:
+            List[Optional[int]]: A list of integers representing the page indices for the specified keys.
+            If a key is not found, the corresponding index will be None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check if the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
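The hunk above only adds the abstract contract; the concrete local and global clients are introduced in the new sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py. As a rough illustration of what the contract expects (not code from this package; the class name is hypothetical), a minimal in-memory client could look like this:

from typing import Dict, List, Optional, Tuple


class InMemoryMetadataClient(Hf3fsMetadataInterface):
    """Hypothetical single-process client mapping keys to page indices."""

    def initialize(self, rank: int, num_pages: int) -> None:
        self.num_pages = num_pages
        self.free_pages: List[int] = list(range(num_pages))
        self.key_to_page: Dict[str, int] = {}

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        out = []
        for key, _prefix_key in keys:
            if key in self.key_to_page:
                out.append((True, self.key_to_page[key]))  # already stored, no I/O needed
            elif self.free_pages:
                out.append((False, self.free_pages.pop()))  # caller writes, then confirms
            else:
                out.append((False, -1))  # no free page available
        return out

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        for key, page in written_keys_to_confirm:
            self.key_to_page[key] = page
        self.free_pages.extend(pages_to_release)  # failed writes return their pages

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        return [self.key_to_page.get(key) for key in keys]

    def delete_keys(self, rank: int, keys: List[str]) -> None:
        for key in keys:
            page = self.key_to_page.pop(key, None)
            if page is not None:
                self.free_pages.append(page)

    def exists(self, rank: int, keys: List[str]) -> List[bool]:
        return [key in self.key_to_page for key in keys]

    def clear(self, rank: int) -> None:
        self.key_to_page.clear()
        self.free_pages = list(range(self.num_pages))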
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):

     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
     ):
+        self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
         self.entries = entries
         self.dtype = dtype
+        self.metadata_client = metadata_client

         self.numel = self.bytes_per_page // self.dtype.itemsize
-
         self.num_pages = self.file_size // self.bytes_per_page

         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )

         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
        )

-        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()

         atexit.register(self.close)

@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )

         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")

+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")

+        # Choose metadata client based on configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use global metadata client to connect to metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )

     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        return self.batch_get([key], [target_location] if target_location else None)[0]

     @synchronized()
     def batch_get(
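With the updated from_env_config, the JSON file referenced by HiCacheHF3FS.default_env_var still has to supply file_path_prefix, file_size, numjobs, and entries, while metadata_server_url is optional and switches to the global metadata client. A hedged sketch of producing such a config from Python; only the key names come from the hunk above, the values and paths are invented placeholders:

import json

# Illustrative config; values are placeholders, not recommendations.
hf3fs_config = {
    "file_path_prefix": "/data/hicache",               # per-rank files become /data/hicache.<rank>.bin
    "file_size": 1 << 40,                              # bytes reserved per rank
    "numjobs": 16,                                     # number of Hf3fsClient instances / worker threads
    "entries": 8,                                      # I/O batch entries passed to the backend
    "metadata_server_url": "http://meta-host:18000",   # omit or leave empty to use the local client
}

with open("/tmp/hf3fs.json", "w") as f:
    json.dump(hf3fs_config, f, indent=2)
# Point the environment variable named by HiCacheHF3FS.default_env_var at /tmp/hf3fs.json.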
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
-        # TODO: target_locations
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
+
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )

         return results

@@ -188,13 +277,21 @@
         return self.batch_set([key], [value])

     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # Todo: Add prefix block's hash key
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
+
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())

         futures = [
@@ -211,62 +308,37 @@
             for result in future.result()
         ]

+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result
-        return all(results)
-
-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        # - is_written: True = hit (no I/O), False = write (miss)
-        # - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))

-        for batch_index, _ in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
             )
-            results[batch_index] = (False, index)

-        return results
+        return all(results)

     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])

     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False

     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)

     def close(self) -> None:
         try:
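Taken together, batch_set now follows a reserve, write, confirm protocol against the metadata client instead of mutating a local OrderedDict. A hypothetical walk-through using the illustrative InMemoryMetadataClient sketched earlier (not code from this package; exact page numbers depend on the client):

client = InMemoryMetadataClient()
client.initialize(rank=0, num_pages=4)

# Reserve pages for two new keys; the second element of each tuple is the
# prefix-block key, which batch_set currently passes as "".
reservations = client.reserve_and_allocate_page_indices(0, [("page-a", ""), ("page-b", "")])
print(reservations)  # e.g. [(False, 3), (False, 2)]: both are misses and must be written

# Suppose the file write for "page-b" fails: confirm only "page-a" and
# release "page-b"'s reserved page back to the pool.
ok_page, failed_page = reservations[0][1], reservations[1][1]
client.confirm_write(0, [("page-a", ok_page)], [failed_page])

print(client.get_page_indices(0, ["page-a", "page-b"]))  # [3, None] with this toy client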
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py

@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
 logger = logging.getLogger(__name__)


-def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
+def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
     local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
-    if prefix_block_key:
-        if len(prefix_block_key):
-            prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
-    current_token_ids_bytes = np.array(current_page_ids).tobytes()
+    if prior_hash:
+        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
+    current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
     return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
@@ -224,13 +223,11 @@ class MooncakeStore(HiCacheStorage):

     def exists(self, keys) -> bool | dict:
         _keys = []
-        local_rank = torch.cuda.current_device()
         for key in keys:
             if key is None:
                 return None
-            # Since mooncake store is stored in layer by layer,
-            # only the first layer is checked here.
-            _keys.append(f"{key}_{local_rank}_k")
+
+            _keys.append(f"{key}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result

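The reworked get_hash_str_mooncake chains keys block by block: each key embeds the SHA-256 of the previous block's key plus a hash of the current token ids, suffixed with the TP rank. A standalone illustration of that chaining; the helper name is hypothetical and a fixed rank stands in for get_tensor_model_parallel_rank():

import hashlib
from typing import List, Optional

import numpy as np


def chained_page_key(token_ids: List[int], prior_hash: Optional[str] = None, rank: int = 0) -> str:
    # Mirrors the logic of the hunk above, for illustration only.
    prefix_str = ""
    if prior_hash:
        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
    current_bytes = np.array(token_ids).tobytes()
    current_hex = hashlib.sha256(current_bytes).hexdigest()
    return f"{prefix_str}_{int(current_hex[:16], 16)}_{rank}"


# Two consecutive pages of one request: the second key depends on the first.
page0 = chained_page_key([1, 2, 3, 4])
page1 = chained_page_key([5, 6, 7, 8], prior_hash=page0)
print(page0)
print(page1)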
sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DpPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    set_dp_buffer_len,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +260,9 @@ class CudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size

+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -342,30 +350,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None

         self.custom_mask = torch.ones(
             (
@@ -549,7 +542,7 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -565,9 +558,9 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None

         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -600,8 +593,8 @@ class CudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -630,6 +623,7 @@ class CudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

             kwargs = {}
             if (
@@ -729,10 +723,12 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)

+        seq_lens_cpu = None
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            seq_lens_cpu = self.seq_lens_cpu[:bs]

         if pp_proxy_tensors:
             for key in self.pp_proxy_tensors.keys():
@@ -747,7 +743,17 @@ class CudaGraphRunner:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
@@ -766,7 +772,7 @@ class CudaGraphRunner:
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu[:bs],
+            seq_lens_cpu=seq_lens_cpu,
         )

         # Store fields
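With the preallocated gathered_buffer replaced by a plain global_dp_buffer_len, graph replay now also derives each attention-TP rank's share of non-padded tokens with a clamp. A small standalone illustration of that arithmetic (values invented; in the diff the clamp runs on a tensor held by CudaGraphRunner):

import torch

# Illustrative only: 2 attention-TP ranks, captured batch of 8, one token per request.
attn_tp_size, num_tokens_per_bs, bs = 2, 1, 8
tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs  # 4 tokens per rank

num_token_non_padded = torch.tensor(5)  # 5 real tokens, 3 padding tokens overall

for attn_tp_rank in range(attn_tp_size):
    local = torch.clamp(
        num_token_non_padded - tokens_per_rank * attn_tp_rank,
        min=0,
        max=tokens_per_rank,
    )
    print(attn_tp_rank, int(local))  # rank 0 -> 4, rank 1 -> 1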
sglang/srt/model_executor/forward_batch_info.py

@@ -40,9 +40,10 @@ import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.dp_attention import (
-    DPPaddingMode,
+    DpPaddingMode,
     get_attention_dp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
@@ -274,13 +275,13 @@ class ForwardBatch:
     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # The padding mode for DP attention
-    dp_padding_mode: Optional[DPPaddingMode] = None
+    dp_padding_mode: Optional[DpPaddingMode] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
-    gathered_buffer: Optional[torch.Tensor] = None
+    global_dp_buffer_len: Optional[int] = None
     is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
@@ -628,7 +629,7 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size

-        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():
@@ -642,23 +643,38 @@ class ForwardBatch:
         else:
             buffer_len = sum(global_num_tokens)

-        self.gathered_buffer = torch.zeros(
-            (buffer_len, model_runner.model_config.hidden_size),
-            dtype=model_runner.dtype,
-            device=model_runner.device,
-        )
-
         if len(global_num_tokens) > 1:
             num_tokens = global_num_tokens[get_attention_dp_rank()]
         else:
             num_tokens = global_num_tokens[0]

-        if self.forward_mode.is_decode():
-            setattr(self, "raw_bs", self.batch_size)
-            self.batch_size = num_tokens
+        self.global_dp_buffer_len = buffer_len
+        set_dp_buffer_len(buffer_len, num_tokens)

         bs = self.batch_size

+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+            else:
+                setattr(self, "_original_batch_size", self.batch_size)
+                if self.spec_info is not None:
+                    bs = self.batch_size = (
+                        num_tokens // self.spec_info.num_tokens_per_batch
+                    )
+                else:
+                    bs = self.batch_size = num_tokens
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -689,6 +705,7 @@ class ForwardBatch:
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)

+        # TODO: check if we need to pad other tensors
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)

@@ -712,7 +729,9 @@ class ForwardBatch:

     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

-        bs = getattr(self, "raw_bs", self.batch_size)
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size

         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
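In the DP-attention preparation path, each rank's token count is rounded up to a multiple of the attention-TP size, and only the flattened buffer length is recorded via set_dp_buffer_len instead of allocating a gathered_buffer tensor per batch. A rough worked example of that arithmetic in pure Python (values invented; the max-len branch of the buffer size is not visible in the hunk, so only the sum mode is shown):

attn_tp_size = 4
global_num_tokens = [5, 9, 1, 12]  # tokens per DP rank before padding

# Round each rank's count up to a multiple of attn_tp_size, mirroring
# ((n - 1) // attn_tp_size + 1) * attn_tp_size from the hunk above.
padded = [((n - 1) // attn_tp_size + 1) * attn_tp_size for n in global_num_tokens]
print(padded)  # [8, 12, 4, 12]

# In sum mode the shared DP buffer only has to hold every rank's padded tokens.
buffer_len = sum(padded)
print(buffer_len)  # 36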