sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
+++ b/sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py
@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch

@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)


+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with the specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple
+                contains a key and the key of its prefix block.
+
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains
+                a boolean indicating whether the key already exists and an
+                integer indicating the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, where each tuple contains
+                a key and its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+
+        Returns:
+            List[Optional[int]]: A list of page indices for the specified keys.
+                If a key is not found, the corresponding index is None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete the specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check whether the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
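
The Hf3fsMetadataInterface above is the abstraction that HiCacheHF3FS now delegates all page bookkeeping to. As a reading aid, here is a minimal in-memory sketch of an implementation; it is hypothetical, single-process, and has no eviction policy, whereas the shipped clients (Hf3fsLocalMetadataClient, Hf3fsGlobalMetadataClient) live in sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py:

class InMemoryMetadataClient(Hf3fsMetadataInterface):
    """Hypothetical single-process metadata client, for illustration only.

    Unlike the real clients, it has no eviction or page-reuse policy:
    a miss with no free page simply reports page index -1.
    """

    def initialize(self, rank: int, num_pages: int) -> None:
        self.num_pages = num_pages
        self.free_pages = list(range(num_pages))
        self.key_to_index = {}

    def reserve_and_allocate_page_indices(self, rank, keys):
        results = []
        for key, _prefix_key in keys:
            if key in self.key_to_index:
                results.append((True, self.key_to_index[key]))  # hit: no I/O needed
            elif self.free_pages:
                results.append((False, self.free_pages.pop()))  # miss: caller writes
            else:
                results.append((False, -1))  # no capacity left
        return results

    def confirm_write(self, rank, written_keys_to_confirm, pages_to_release):
        for key, page_index in written_keys_to_confirm:
            self.key_to_index[key] = page_index
        self.free_pages.extend(pages_to_release)  # reclaim pages from failed writes

    def get_page_indices(self, rank, keys):
        return [self.key_to_index.get(key) for key in keys]

    def delete_keys(self, rank, keys):
        for key in keys:
            if key in self.key_to_index:
                self.free_pages.append(self.key_to_index.pop(key))

    def exists(self, rank, keys):
        return [key in self.key_to_index for key in keys]

    def clear(self, rank):
        self.initialize(rank, self.num_pages)
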
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):

     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
     ):
+        self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
         self.entries = entries
         self.dtype = dtype
+        self.metadata_client = metadata_client

         self.numel = self.bytes_per_page // self.dtype.itemsize
-
         self.num_pages = self.file_size // self.bytes_per_page

         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )

         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
         )

-        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()

         atexit.register(self.close)

@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )

         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")

+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")

+        # Choose metadata client based on configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use global metadata client to connect to metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )

     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        return self.batch_get([key], [target_location] if target_location else None)[0]

     @synchronized()
     def batch_get(
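
For illustration, a config consumed by from_env_config might look like the following sketch. All values here are invented placeholders; only file_path_prefix, file_size, numjobs, and entries are required, and a non-empty metadata_server_url opts into the global client:

import json

# Hypothetical HiCacheHF3FS config, written to the path named by the
# environment variable HiCacheHF3FS.default_env_var refers to.
config = {
    "file_path_prefix": "/data/hicache",  # per-rank file: /data/hicache.<rank>.bin
    "file_size": 1 << 40,                 # bytes per rank file
    "numjobs": 16,
    "entries": 8,
    # Optional: when present and non-empty, Hf3fsGlobalMetadataClient is used;
    # otherwise Hf3fsLocalMetadataClient.
    "metadata_server_url": "http://meta-server:18000",
}
with open("/tmp/hf3fs_config.json", "w") as f:
    json.dump(config, f)
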
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
-        # TODO: target_locations
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
+
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )

         return results

@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
         return self.batch_set([key], [value])

     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # Todo: Add prefix block's hash key
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
+
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())

         futures = [
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]

+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result
-        return all(results)
-
-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        # - is_written: True = hit (no I/O), False = write (miss)
-        # - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))

-        for batch_index, _ in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
             )
-            results[batch_index] = (False, index)

-        return results
+        return all(results)

     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])

     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False

     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)

     def close(self) -> None:
         try:
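
Taken together, the rewritten batch_set/delete/exists/clear reduce to a reserve-then-confirm protocol against the metadata client. A usage sketch with the illustrative InMemoryMetadataClient from earlier (all keys and values invented):

client = InMemoryMetadataClient()
client.initialize(rank=0, num_pages=4)

keys = ["blk_a", "blk_b"]
indices = client.reserve_and_allocate_page_indices(0, [(k, "") for k in keys])
# -> [(False, 3), (False, 2)]: two misses, pages 3 and 2 reserved for writing.

# Suppose the file write for "blk_b" failed: confirm "blk_a", release page 2.
client.confirm_write(0, written_keys_to_confirm=[("blk_a", 3)], pages_to_release=[2])
assert client.exists(0, keys) == [True, False]
assert client.get_page_indices(0, keys) == [3, None]
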
--- a/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
+++ b/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py
@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
 logger = logging.getLogger(__name__)


-def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
+def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
     local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
-    if prefix_block_key:
-        if len(prefix_block_key):
-            prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
-    current_token_ids_bytes = np.array(current_page_ids).tobytes()
+    if prior_hash:
+        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
+    current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
     return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
@@ -224,13 +223,11 @@ class MooncakeStore(HiCacheStorage):

     def exists(self, keys) -> bool | dict:
         _keys = []
-        local_rank = torch.cuda.current_device()
         for key in keys:
             if key is None:
                 return None
-            # Since mooncake store is stored in layer by layer,
-            # only the first layer is checked here.
-            _keys.append(f"{key}_{local_rank}_k")
+
+            _keys.append(f"{key}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result

--- a/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +259,9 @@ class CudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size

+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -576,11 +583,11 @@ class CudaGraphRunner:
         )

         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA path, as the LoRA kernels will always be launched whenever
-            # `--enable-lora` is set to True (and return immediately if the LoRA path is empty for perf optimization).
-            lora_paths = [None] * bs
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-            lora_paths = None
+            lora_ids = None

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -589,6 +596,7 @@ class CudaGraphRunner:
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -607,11 +615,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-            lora_paths=lora_paths,
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

-        if lora_paths is not None:
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

         # Attention backend
@@ -728,10 +736,12 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)

+        seq_lens_cpu = None
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            seq_lens_cpu = self.seq_lens_cpu[:bs]

         if pp_proxy_tensors:
             for key in self.pp_proxy_tensors.keys():
@@ -746,7 +756,17 @@ class CudaGraphRunner:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
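
The clamp above carves the global non-padded token count into a per-rank share under attention tensor parallelism. A small self-contained rehearsal of the same arithmetic (all values invented):

import torch

# bs=8 padded requests, 1 token each at decode, attn_tp_size=2:
# each rank owns tokens_per_rank = 8 // 2 * 1 = 4 token slots.
# With 6 globally non-padded tokens, rank 0 keeps 4 and rank 1 keeps 2.
bs, num_tokens_per_bs, attn_tp_size = 8, 1, 2
num_token_non_padded = torch.tensor(6)
tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs
for attn_tp_rank in range(attn_tp_size):
    local = torch.clamp(
        num_token_non_padded - tokens_per_rank * attn_tp_rank,
        min=0,
        max=tokens_per_rank,
    )
    print(attn_tp_rank, int(local))  # prints: 0 4, then 1 2
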
@@ -765,7 +785,7 @@ class CudaGraphRunner:
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu[:bs],
+            seq_lens_cpu=seq_lens_cpu,
         )

         # Store fields
--- a/sglang/srt/model_executor/forward_batch_info.py
+++ b/sglang/srt/model_executor/forward_batch_info.py
@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None

@@ -248,7 +251,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None

     # For LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_ids: Optional[List[str]] = None

     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -321,13 +324,14 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-            lora_paths=batch.lora_paths,
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -420,16 +424,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-            if support_triton(model_runner.server_args.attention_backend):
-                positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens,
-                    ret.extend_seq_lens,
-                    ret.extend_num_tokens,
-                )
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -632,8 +632,10 @@ class ForwardBatch:
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():
-            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
-            # where transferred tokens should be padded to the same length.
+            # when DP gather mode is all gather, we will use
+            # all_gather_into_tensor to gather hidden states, where transferred
+            # tokens should be padded to the same length. We will also use
+            # reduce-scatter instead of all-reduce after MLP.
             max_num_tokens = max(global_num_tokens)
             global_num_tokens = [max_num_tokens] * sync_group_size
             buffer_len = max_num_tokens * sync_group_size
@@ -651,12 +653,30 @@ class ForwardBatch:
         else:
             num_tokens = global_num_tokens[0]

-        if self.forward_mode.is_decode():
-            setattr(self, "raw_bs", self.batch_size)
-            self.batch_size = num_tokens
-
         bs = self.batch_size

+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+            else:
+                setattr(self, "_original_batch_size", self.batch_size)
+                if self.spec_info is not None:
+                    bs = self.batch_size = (
+                        num_tokens // self.spec_info.num_tokens_per_batch
+                    )
+                else:
+                    bs = self.batch_size = num_tokens
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
687
707
  if self.mrope_positions is not None:
688
708
  self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
689
709
 
710
+ # TODO: check if we need to pad other tensors
690
711
  if self.extend_seq_lens is not None:
691
712
  self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
692
713
 
@@ -710,7 +731,9 @@ class ForwardBatch:

     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

-        bs = getattr(self, "raw_bs", self.batch_size)
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size

         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft
@@ -882,6 +905,25 @@ class PPProxyTensors:
         return f"PPProxyTensors(tensors={self.tensors})"


+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
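
For orientation, the new compute_position wrapper only dispatches between the Triton and torch implementations that follow it. The sketch below reproduces what the outputs are expected to contain, as inferred from how callers use positions and extend_start_loc; it is not the package's kernel code:

import torch

# For prefix lens [4, 8] and extend lens [2, 3]: positions are the absolute
# token positions of the new tokens, flattened across requests, and
# extend_start_loc is each request's offset into that flattened buffer.
extend_prefix_lens = torch.tensor([4, 8])
extend_seq_lens = torch.tensor([2, 3])
positions = torch.cat(
    [torch.arange(p, p + s) for p, s in zip(extend_prefix_lens, extend_seq_lens)]
)
extend_start_loc = torch.cumsum(extend_seq_lens, 0) - extend_seq_lens
print(positions.tolist())         # [4, 5, 8, 9, 10]
print(extend_start_loc.tolist())  # [0, 2]
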