sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -0
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +7 -7
  6. sglang/srt/disaggregation/decode.py +8 -3
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +4 -5
  14. sglang/srt/entrypoints/openai/protocol.py +0 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +59 -265
  16. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  17. sglang/srt/function_call/ebnf_composer.py +1 -0
  18. sglang/srt/function_call/function_call_parser.py +2 -0
  19. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  20. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  21. sglang/srt/function_call/kimik2_detector.py +3 -3
  22. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  23. sglang/srt/jinja_template_utils.py +6 -0
  24. sglang/srt/layers/attention/aiter_backend.py +370 -107
  25. sglang/srt/layers/attention/ascend_backend.py +3 -0
  26. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  27. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  28. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  29. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  30. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  31. sglang/srt/layers/attention/vision.py +9 -1
  32. sglang/srt/layers/attention/wave_backend.py +627 -0
  33. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  34. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  35. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  36. sglang/srt/layers/communicator.py +8 -10
  37. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  38. sglang/srt/layers/linear.py +1 -0
  39. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  41. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  42. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  43. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  46. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  47. sglang/srt/layers/moe/topk.py +4 -1
  48. sglang/srt/layers/quantization/__init__.py +5 -3
  49. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  50. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  51. sglang/srt/layers/quantization/modelopt_quant.py +6 -11
  52. sglang/srt/layers/quantization/mxfp4.py +4 -1
  53. sglang/srt/layers/quantization/w4afp8.py +20 -11
  54. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  55. sglang/srt/layers/rotary_embedding.py +281 -2
  56. sglang/srt/lora/backend/base_backend.py +3 -23
  57. sglang/srt/lora/layers.py +60 -114
  58. sglang/srt/lora/lora.py +17 -62
  59. sglang/srt/lora/lora_manager.py +12 -48
  60. sglang/srt/lora/lora_registry.py +20 -9
  61. sglang/srt/lora/mem_pool.py +20 -63
  62. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  63. sglang/srt/lora/utils.py +25 -58
  64. sglang/srt/managers/cache_controller.py +21 -29
  65. sglang/srt/managers/detokenizer_manager.py +1 -1
  66. sglang/srt/managers/io_struct.py +6 -6
  67. sglang/srt/managers/mm_utils.py +1 -2
  68. sglang/srt/managers/multimodal_processor.py +1 -1
  69. sglang/srt/managers/schedule_batch.py +35 -20
  70. sglang/srt/managers/schedule_policy.py +6 -6
  71. sglang/srt/managers/scheduler.py +15 -7
  72. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  73. sglang/srt/managers/tokenizer_manager.py +25 -26
  74. sglang/srt/mem_cache/allocator.py +61 -87
  75. sglang/srt/mem_cache/hicache_storage.py +1 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  77. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  78. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  79. sglang/srt/mem_cache/radix_cache.py +2 -5
  80. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  81. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  82. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  83. sglang/srt/model_executor/cuda_graph_runner.py +22 -3
  84. sglang/srt/model_executor/forward_batch_info.py +26 -5
  85. sglang/srt/model_executor/model_runner.py +129 -35
  86. sglang/srt/model_loader/loader.py +18 -6
  87. sglang/srt/models/deepseek_v2.py +74 -35
  88. sglang/srt/models/gemma2.py +0 -34
  89. sglang/srt/models/gemma3n_mm.py +8 -9
  90. sglang/srt/models/glm4.py +6 -0
  91. sglang/srt/models/glm4_moe.py +9 -9
  92. sglang/srt/models/glm4v.py +589 -0
  93. sglang/srt/models/glm4v_moe.py +400 -0
  94. sglang/srt/models/gpt_oss.py +136 -19
  95. sglang/srt/models/granite.py +0 -25
  96. sglang/srt/models/llama.py +0 -25
  97. sglang/srt/models/llama4.py +1 -1
  98. sglang/srt/models/qwen2_5_vl.py +7 -3
  99. sglang/srt/models/qwen2_audio.py +10 -9
  100. sglang/srt/models/qwen3.py +0 -24
  101. sglang/srt/models/registry.py +1 -1
  102. sglang/srt/models/torch_native_llama.py +0 -24
  103. sglang/srt/multimodal/processors/base_processor.py +23 -13
  104. sglang/srt/multimodal/processors/glm4v.py +132 -0
  105. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  106. sglang/srt/reasoning_parser.py +316 -0
  107. sglang/srt/server_args.py +115 -139
  108. sglang/srt/speculative/eagle_worker.py +16 -0
  109. sglang/srt/two_batch_overlap.py +12 -4
  110. sglang/srt/utils.py +3 -3
  111. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  112. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  113. sglang/test/doc_patch.py +59 -0
  114. sglang/test/few_shot_gsm8k.py +1 -1
  115. sglang/test/few_shot_gsm8k_engine.py +1 -1
  116. sglang/test/run_eval.py +4 -1
  117. sglang/test/simple_eval_common.py +6 -0
  118. sglang/test/simple_eval_gpqa.py +2 -0
  119. sglang/test/test_fp4_moe.py +118 -36
  120. sglang/utils.py +1 -1
  121. sglang/version.py +1 -1
  122. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
  123. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
  124. sglang/lang/backend/__init__.py +0 -0
  125. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  126. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  127. /sglang/{api.py → lang/api.py} +0 -0
  128. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  129. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py

@@ -5,9 +5,9 @@ import logging
 import os
 import signal
 import threading
-from collections import OrderedDict
+from abc import ABC, abstractmethod
 from functools import wraps
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -17,6 +17,75 @@ from sglang.srt.mem_cache.storage.hf3fs.client_hf3fs import Hf3fsClient
 logger = logging.getLogger(__name__)
 
 
+class Hf3fsMetadataInterface(ABC):
+    """Interface for HF3FS metadata operations."""
+
+    @abstractmethod
+    def initialize(self, rank: int, num_pages: int) -> None:
+        """Initialize the metadata service with the specified number of pages."""
+        pass
+
+    @abstractmethod
+    def reserve_and_allocate_page_indices(
+        self,
+        rank: int,
+        keys: List[Tuple[str, str]],
+    ) -> List[Tuple[bool, int]]:
+        """
+        Reserve and allocate page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: The keys to reserve and allocate page indices for. Each tuple
+                contains a key and the key of its prefix block.
+        Returns:
+            List[Tuple[bool, int]]: A list of tuples, where each tuple contains a
+                boolean indicating whether the key already exists and an integer
+                giving the allocated page index.
+        """
+        pass
+
+    @abstractmethod
+    def confirm_write(
+        self,
+        rank: int,
+        written_keys_to_confirm: List[Tuple[str, int]],
+        pages_to_release: List[int],
+    ) -> None:
+        """
+        Confirm that key-value pairs have been successfully written to storage.
+        Args:
+            rank: The rank of the process.
+            written_keys_to_confirm: A list of tuples, where each tuple contains a
+                key and its corresponding page index.
+            pages_to_release: A list of page indices to be released.
+        """
+        pass
+
+    @abstractmethod
+    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
+        """
+        Get page indices for the specified keys.
+        Args:
+            rank: The rank of the process.
+            keys: A list of keys.
+        Returns:
+            List[Optional[int]]: A list of page indices for the specified keys.
+                If a key is not found, the corresponding index is None.
+        """
+        pass
+
+    @abstractmethod
+    def delete_keys(self, rank: int, keys: List[str]) -> None:
+        """Delete the specified keys and their associated pages."""
+        pass
+
+    @abstractmethod
+    def exists(self, rank: int, keys: List[str]) -> List[bool]:
+        """Check whether the specified keys exist."""
+        pass
+
+    @abstractmethod
+    def clear(self, rank: int) -> None:
+        """Clear all key-value pairs and page allocations for the specified rank."""
+        pass
+
+
 class AtomicCounter:
     def __init__(self, n: int):
         assert n > 0
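For orientation, the interface above encodes a reserve/confirm allocation protocol: a page is reserved on a miss, confirmed once the write succeeds, and released if it fails. The sketch below is a hypothetical in-process implementation of the core methods, written only to illustrate the expected call pattern; it is not the shipped Hf3fsLocalMetadataClient.

from typing import Dict, List, Optional, Tuple

class ToyMetadataClient:
    """Illustrative only: a dict-backed stand-in for Hf3fsMetadataInterface."""

    def initialize(self, rank: int, num_pages: int) -> None:
        self.free_pages: List[int] = list(range(num_pages))
        self.key_to_index: Dict[str, int] = {}

    def reserve_and_allocate_page_indices(
        self, rank: int, keys: List[Tuple[str, str]]
    ) -> List[Tuple[bool, int]]:
        out = []
        for key, _prefix_key in keys:
            if key in self.key_to_index:
                out.append((True, self.key_to_index[key]))  # hit: no write needed
            elif self.free_pages:
                out.append((False, self.free_pages.pop()))  # miss: page reserved
            else:
                out.append((False, -1))  # no free page available
        return out

    def confirm_write(
        self,
        rank: int,
        written_keys_to_confirm: List[Tuple[str, int]],
        pages_to_release: List[int],
    ) -> None:
        for key, page_index in written_keys_to_confirm:
            self.key_to_index[key] = page_index  # key becomes visible to readers
        self.free_pages.extend(pages_to_release)  # failed writes return their pages

    def get_page_indices(self, rank: int, keys: List[str]) -> List[Optional[int]]:
        return [self.key_to_index.get(key) for key in keys]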
@@ -48,32 +117,32 @@ class HiCacheHF3FS(HiCacheStorage):
 
     def __init__(
         self,
+        rank: int,
         file_path: str,
         file_size: int,
         numjobs: int,
         bytes_per_page: int,
         entries: int,
         dtype: torch.dtype,
+        metadata_client: Hf3fsMetadataInterface,
     ):
+        self.rank = rank
         self.file_path = file_path
         self.file_size = file_size
         self.numjobs = numjobs
         self.bytes_per_page = bytes_per_page
         self.entries = entries
         self.dtype = dtype
+        self.metadata_client = metadata_client
 
         self.numel = self.bytes_per_page // self.dtype.itemsize
-
         self.num_pages = self.file_size // self.bytes_per_page
 
         logger.info(
-            "HiCacheHF3FS "
-            f"file_path = {self.file_path}, "
-            f"file_size = {self.file_size/(2**30):.2f} GB, "
-            f"numjobs = {self.numjobs}, "
-            f"bytes_per_page = {self.bytes_per_page/(2**20):.2f} MB, "
-            f"entries = {self.entries}, "
-            f"num_pages = {self.num_pages}"
+            f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
+            f"file_path={self.file_path}, "
+            f"file_size={self.file_size / (2 ** 30):.2f} GB, "
+            f"num_pages={self.num_pages}"
         )
 
         self.ac = AtomicCounter(self.numjobs)
@@ -84,15 +153,11 @@ class HiCacheHF3FS(HiCacheStorage):
             for _ in range(numjobs)
         ]
         self.executor = concurrent.futures.ThreadPoolExecutor(
-            max_workers=self.numjobs, thread_name_prefix="HiCacheHF3FS"
+            max_workers=self.numjobs, thread_name_prefix=f"HiCacheHF3FS-Rank{self.rank}"
        )
 
-        # Implemented a preliminary single-file page_hash -> file_offset index as interim storage.
-        # Future iterations may adopt a global KVCache manager to coordinate external cache instances
-        # through centralized metadata orchestration.
+        self.metadata_client.initialize(self.rank, self.num_pages)
         self.lock = threading.RLock()
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index = OrderedDict()
 
         atexit.register(self.close)
 
@@ -104,15 +169,22 @@ class HiCacheHF3FS(HiCacheStorage):
     def from_env_config(
         rank: int, bytes_per_page: int, dtype: torch.dtype
     ) -> "HiCacheHF3FS":
+        from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
+            Hf3fsGlobalMetadataClient,
+            Hf3fsLocalMetadataClient,
+        )
+
         config_path = os.getenv(HiCacheHF3FS.default_env_var)
         if not config_path:
             return HiCacheHF3FS(
+                rank=rank,
                 file_path=f"/data/hicache.{rank}.bin",
                 file_size=1 << 40,
                 numjobs=16,
                 bytes_per_page=bytes_per_page,
                 entries=8,
                 dtype=dtype,
+                metadata_client=Hf3fsLocalMetadataClient(),
             )
 
         try:
@@ -121,6 +193,7 @@ class HiCacheHF3FS(HiCacheStorage):
         except Exception as e:
             raise RuntimeError(f"Failed to load config from {config_path}: {str(e)}")
 
+        # Check required keys (metadata_server_url is now optional)
         required_keys = {
             "file_path_prefix",
             "file_size",
@@ -131,19 +204,33 @@ class HiCacheHF3FS(HiCacheStorage):
         if missing_keys:
             raise ValueError(f"Missing required keys in config: {missing_keys}")
 
+        # Choose metadata client based on configuration
+        if "metadata_server_url" in config and config["metadata_server_url"]:
+            # Use global metadata client to connect to metadata server
+            metadata_server_url = config["metadata_server_url"]
+            metadata_client = Hf3fsGlobalMetadataClient(metadata_server_url)
+            logger.info(
+                f"Using global metadata client with server url: {metadata_server_url}"
+            )
+        else:
+            # Use local metadata client for single-machine deployment
+            metadata_client = Hf3fsLocalMetadataClient()
+
         return HiCacheHF3FS(
+            rank=rank,
             file_path=f"{config['file_path_prefix']}.{rank}.bin",
             file_size=int(config["file_size"]),
             numjobs=int(config["numjobs"]),
             bytes_per_page=bytes_per_page,
             entries=int(config["entries"]),
             dtype=dtype,
+            metadata_client=metadata_client,
         )
 
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
-        return self.batch_get([key], target_location)[0]
+        return self.batch_get([key], [target_location] if target_location else None)[0]
 
     @synchronized()
     def batch_get(
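For reference, a config file pointed to by HiCacheHF3FS.default_env_var could look like the dict below (shown as Python for brevity). Only the key names are taken from the code above; the concrete values and the server URL are illustrative.

# Hypothetical config consumed by from_env_config.
example_config = {
    "file_path_prefix": "/data/hicache",  # expands to /data/hicache.<rank>.bin
    "file_size": 1 << 40,                 # bytes of HF3FS space per rank
    "numjobs": 16,                        # parallel I/O clients per rank
    "entries": 8,
    # Optional: when set, Hf3fsGlobalMetadataClient is used so multiple
    # ranks share one page index; when absent or empty,
    # Hf3fsLocalMetadataClient keeps the index in-process.
    "metadata_server_url": "http://metadata-server:18000",
}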
@@ -151,14 +238,14 @@ class HiCacheHF3FS(HiCacheStorage):
         keys: List[str],
         target_locations: Optional[List[torch.Tensor]] = None,
     ) -> List[torch.Tensor | None]:
+        page_indices = self.metadata_client.get_page_indices(self.rank, keys)
+
         batch_indices, file_offsets = [], []
-        for i, key in enumerate(keys):
-            if key not in self.key_to_index:
-                continue
-            batch_indices.append(i)
-            file_offsets.append(self.key_to_index[key] * self.bytes_per_page)
-            self.key_to_index.move_to_end(key)
-            # TODO: target_locations
+        for i, page_index in enumerate(page_indices):
+            if page_index is not None:
+                batch_indices.append(i)
+                file_offsets.append(page_index * self.bytes_per_page)
+
         file_results = [
             torch.empty(self.numel, dtype=self.dtype) for _ in range(len(batch_indices))
         ]
@@ -180,7 +267,9 @@ class HiCacheHF3FS(HiCacheStorage):
             if read_result == self.bytes_per_page:
                 results[batch_index] = file_result
             else:
-                logger.error(f"HiCacheHF3FS get {keys[batch_index]} failed")
+                logger.error(
+                    f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
+                )
 
         return results
 
@@ -188,13 +277,21 @@ class HiCacheHF3FS(HiCacheStorage):
         return self.batch_set([key], [value])
 
     def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        indices = self.get_batch_set_indices(keys)
+        # Todo: Add prefix block's hash key
+        key_with_prefix = [(key, "") for key in keys]
+        indices = self.metadata_client.reserve_and_allocate_page_indices(
+            self.rank, key_with_prefix
+        )
+
         batch_indices, file_offsets, file_values = [], [], []
-        for i, (value, (is_written, index)) in enumerate(zip(values, indices)):
-            if is_written or index == -1:
+        pages_to_release = []
+
+        for i, (value, (is_written, page_index)) in enumerate(zip(values, indices)):
+            if is_written or page_index == -1:
                 continue
+
             batch_indices.append(i)
-            file_offsets.append(index * self.bytes_per_page)
+            file_offsets.append(page_index * self.bytes_per_page)
             file_values.append(value.contiguous())
 
         futures = [
@@ -211,62 +308,37 @@ class HiCacheHF3FS(HiCacheStorage):
             for result in future.result()
         ]
 
+        written_keys_to_confirm = []
         results = [index[0] for index in indices]
         for batch_index, write_result in zip(batch_indices, write_results):
             key = keys[batch_index]
-            index = indices[batch_index][1]
+            page_index = indices[batch_index][1]
             if write_result:
-                self.key_to_index[key] = index
-                self.key_to_index.move_to_end(key)
+                written_keys_to_confirm.append((key, page_index))
             else:
-                logger.error(f"HiCacheHF3FS set {key} failed")
-                self.free_pages.append(index)
+                logger.error(f"[Rank {self.rank}] HiCacheHF3FS set {key} failed")
+                pages_to_release.append(page_index)
             results[batch_index] = write_result
-        return all(results)
-
-    @synchronized()
-    def get_batch_set_indices(self, keys: List[str]) -> list:
-        ionum = len(keys)
-        # results: tuples of (is_written: bool, page_idx: int)
-        #   - is_written: True = hit (no I/O), False = write (miss)
-        #   - page_idx: page storing data
-        results = [None] * min(ionum, self.num_pages)
-        if ionum > self.num_pages:
-            results.extend([(False, -1)] * (ionum - self.num_pages))
-
-        new_keys = []
-        for batch_index, key in enumerate(keys[: self.num_pages]):
-            if key in self.key_to_index:
-                results[batch_index] = (True, self.key_to_index[key])
-                self.key_to_index.move_to_end(key)
-            else:
-                new_keys.append((batch_index, key))
 
-        for batch_index, _ in new_keys:
-            index = (
-                self.free_pages.pop()
-                if len(self.free_pages) > 0
-                else self.key_to_index.popitem(last=False)[1]
+        if len(written_keys_to_confirm) > 0 or len(pages_to_release) > 0:
+            self.metadata_client.confirm_write(
+                self.rank, written_keys_to_confirm, pages_to_release
             )
-            results[batch_index] = (False, index)
 
-        return results
+        return all(results)
 
     @synchronized()
     def delete(self, key: str) -> None:
-        if key not in self.key_to_index:
-            return
-        index = self.key_to_index.pop(key)
-        self.free_pages.append(index)
+        self.metadata_client.delete_keys(self.rank, [key])
 
     @synchronized()
     def exists(self, key: str) -> bool:
-        return key in self.key_to_index
+        result = self.metadata_client.exists(self.rank, [key])
+        return result[0] if result else False
 
     @synchronized()
     def clear(self) -> None:
-        self.free_pages = list(range(self.num_pages))
-        self.key_to_index.clear()
+        self.metadata_client.clear(self.rank)
 
     def close(self) -> None:
         try:
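A hypothetical round-trip through the rewritten set/get path, assuming an environment where Hf3fsClient can open the given file path; the path, sizes, and key are made up, and a real deployment would construct the client via from_env_config:

import torch
from sglang.srt.mem_cache.storage.hf3fs.mini_3fs_metadata_server import (
    Hf3fsLocalMetadataClient,
)
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

cache = HiCacheHF3FS(
    rank=0,
    file_path="/data/hicache.0.bin",  # illustrative HF3FS-backed path
    file_size=1 << 30,
    numjobs=4,
    bytes_per_page=64 << 10,
    entries=8,
    dtype=torch.bfloat16,
    metadata_client=Hf3fsLocalMetadataClient(),
)

page = torch.zeros(cache.numel, dtype=cache.dtype)
ok = cache.batch_set(["page_hash_a"], [page])  # reserve -> write -> confirm
hit = cache.batch_get(["page_hash_a"])[0]      # tensor on hit, None on miss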
sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py

@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024  # 128 MB
 logger = logging.getLogger(__name__)
 
 
-def get_hash_str_mooncake(current_page_ids: List, prefix_block_key: str):
+def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
     local_rank = get_tensor_model_parallel_rank()
     prefix_str = ""
-    if prefix_block_key:
-        if len(prefix_block_key):
-            prefix_str = hashlib.sha256(prefix_block_key.encode()).hexdigest()
-    current_token_ids_bytes = np.array(current_page_ids).tobytes()
+    if prior_hash:
+        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
+    current_token_ids_bytes = np.array(token_ids).tobytes()
     current_hash_object = hashlib.sha256(current_token_ids_bytes)
     current_hash_hex = current_hash_object.hexdigest()
     return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
@@ -224,13 +223,11 @@ class MooncakeStore(HiCacheStorage):
 
     def exists(self, keys) -> bool | dict:
         _keys = []
-        local_rank = torch.cuda.current_device()
         for key in keys:
             if key is None:
                 return None
-            # Since mooncake store is stored layer by layer,
-            # only the first layer is checked here.
-            _keys.append(f"{key}_{local_rank}_k")
+
+            _keys.append(f"{key}_k")
         result = {k: v for k, v in zip(keys, self.store.batch_is_exist(_keys))}
         return result
 
sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +259,9 @@ class CudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
 
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -729,10 +736,12 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
 
+        seq_lens_cpu = None
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         if pp_proxy_tensors:
             for key in self.pp_proxy_tensors.keys():
@@ -747,7 +756,17 @@ class CudaGraphRunner:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
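The clamp above slices the global non-padded token count into contiguous per-rank shares when a gathered buffer is required. A worked example with made-up sizes:

import torch

bs, attn_tp_size, num_tokens_per_bs = 8, 2, 1  # hypothetical capture sizes
tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs  # 4

num_token_non_padded = torch.tensor(6)  # 6 real tokens, 2 padding tokens
for attn_tp_rank in range(attn_tp_size):
    local = torch.clamp(
        num_token_non_padded - tokens_per_rank * attn_tp_rank,
        min=0,
        max=tokens_per_rank,
    )
    print(attn_tp_rank, int(local))  # rank 0 -> 4, rank 1 -> 2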
@@ -766,7 +785,7 @@ class CudaGraphRunner:
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu[:bs],
+            seq_lens_cpu=seq_lens_cpu,
         )
 
         # Store fields
sglang/srt/model_executor/forward_batch_info.py

@@ -653,12 +653,30 @@ class ForwardBatch:
         else:
             num_tokens = global_num_tokens[0]
 
-        if self.forward_mode.is_decode():
-            setattr(self, "raw_bs", self.batch_size)
-            self.batch_size = num_tokens
-
         bs = self.batch_size
 
+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+            else:
+                setattr(self, "_original_batch_size", self.batch_size)
+                if self.spec_info is not None:
+                    bs = self.batch_size = (
+                        num_tokens // self.spec_info.num_tokens_per_batch
+                    )
+                else:
+                    bs = self.batch_size = num_tokens
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
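When a decode batch is re-labeled as EXTEND for max-len DP padding, every request is treated as extending its sequence by exactly one token. A toy illustration of the metadata this synthesizes, with made-up sequence lengths:

import torch

seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32)  # hypothetical decode batch
bs = seq_lens.numel()

extend_seq_lens = torch.full_like(seq_lens, 1)          # one new token per request
extend_prefix_lens = seq_lens - 1                       # tensor([4, 8, 2])
extend_start_loc = torch.arange(bs, dtype=torch.int32)  # tensor([0, 1, 2])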
@@ -689,6 +707,7 @@ class ForwardBatch:
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
 
+        # TODO: check if we need to pad other tensors
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
 
@@ -712,7 +731,9 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
-        bs = getattr(self, "raw_bs", self.batch_size)
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():  # draft