sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/memory_pool.py
@@ -15,6 +15,9 @@ limitations under the License.
 
 from __future__ import annotations
 
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -109,17 +112,38 @@ class ReqToTokenPool:
 
 
 class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
     def __init__(
         self,
+        *,
         size: int,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        num_mamba_layers: int,
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
+        cache_params: "Mamba2CacheParams",
         device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # assume conv_state = (dim, state_len)
+        assert conv_state_shape[0] > conv_state_shape[1]
         conv_state = torch.zeros(
             size=(num_mamba_layers, size + 1) + conv_state_shape,
             dtype=conv_dtype,
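
Note: the nested State/SpeculativeState dataclasses replace the bare tuples the pool stored previously; at_layer_idx slices every field along the leading layer dimension. A minimal standalone sketch of the pattern (illustrative shapes; numel() * element_size() stands in for sglang's get_tensor_size_bytes helper):

    import torch
    from dataclasses import dataclass

    @dataclass(frozen=True, kw_only=True)
    class State:
        conv: torch.Tensor      # [num_layers, size + 1, *conv_state_shape]
        temporal: torch.Tensor  # [num_layers, size + 1, *temporal_state_shape]

        def at_layer_idx(self, layer: int):
            # Slice every field along the leading (layer) dimension.
            return type(self)(**{k: v[layer] for k, v in vars(self).items()})

        def mem_usage_bytes(self):
            # Stand-in for sglang's get_tensor_size_bytes.
            return sum(t.numel() * t.element_size() for t in vars(self).values())

    cache = State(
        conv=torch.zeros(4, 9, 128, 3),        # 4 mamba layers, 8 slots + 1
        temporal=torch.zeros(4, 9, 8, 16, 32),
    )
    layer0 = cache.at_layer_idx(0)             # State(conv=conv[0], temporal=temporal[0])
    print(layer0.conv.shape, cache.mem_usage_bytes())
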
@@ -158,11 +182,11 @@ class MambaPool:
                 dtype=conv_dtype,
                 device="cuda",
             )
-            self.mamba_cache = (
-                conv_state,
-                temporal_state,
-                intermediate_ssm_state_cache,
-                intermediate_conv_window_cache,
+            self.mamba_cache = self.SpeculativeState(
+                conv=conv_state,
+                temporal=temporal_state,
+                intermediate_ssm=intermediate_ssm_state_cache,
+                intermediate_conv_window=intermediate_conv_window_cache,
             )
             logger.info(
                 f"Mamba Cache is allocated. "
@@ -172,7 +196,7 @@ class MambaPool:
                 f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
             )
         else:
-            self.mamba_cache = (conv_state, temporal_state)
+            self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
             logger.info(
                 f"Mamba Cache is allocated. "
                 f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
@@ -180,16 +204,14 @@ class MambaPool:
             )
         self.size = size
         self.free_slots = list(range(size))
-        self.mem_usage = self.get_mamba_size() / GB
-
-    def get_mamba_params_all_layers(self):
-        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
 
-    def get_mamba_params(self, layer_id: int):
-        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
 
-    def get_mamba_size(self):
-        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
 
     def available_size(self):
         return len(self.free_slots)
@@ -208,7 +230,9 @@ class MambaPool:
             self.free_slots.append(free_index)
         else:
             self.free_slots.extend(free_index)
-        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
 
     def clear(self):
         self.free_slots = list(range(self.size))
@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def __init__(
         self,
+        *,
         size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        mamba_layers: List[int],
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
-        speculative_num_draft_tokens: int,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
     ):
         super().__init__(
             size=size,
@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
         )
 
         self.mamba_pool = MambaPool(
-            size,
-            conv_dtype,
-            ssm_dtype,
-            len(mamba_layers),
-            conv_state_shape,
-            temporal_state_shape,
-            device,
-            speculative_num_draft_tokens,
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
-        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
 
         self.device = device
         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
         return self.req_index_to_mamba_index_mapping[req_indices]
 
-    def get_mamba_params(self, layer_id: int):
+    def mamba2_layer_cache(self, layer_id: int):
         assert layer_id in self.mamba_map
-        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
 
-    def get_mamba_params_all_layers(self):
-        return self.mamba_pool.get_mamba_params_all_layers()
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
 
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
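
Note: both constructors above now take a single Mamba2CacheParams bundle (defined in the new sglang/srt/configs/mamba_utils.py, not shown in this diff) instead of separate dtype/shape/layer arguments. A hypothetical stand-in that mirrors only the attributes the pools read (shape.conv, shape.temporal, dtype.conv, dtype.temporal, layers); the real class may differ:

    from dataclasses import dataclass
    from typing import List, Tuple

    import torch

    @dataclass(frozen=True)
    class FakeShape:
        conv: Tuple[int, int]          # (dim, state_len), with dim > state_len
        temporal: Tuple[int, ...]

    @dataclass(frozen=True)
    class FakeDtype:
        conv: torch.dtype
        temporal: torch.dtype

    @dataclass(frozen=True)
    class FakeMamba2CacheParams:
        shape: FakeShape
        dtype: FakeDtype
        layers: List[int]              # indices of the mamba layers

    params = FakeMamba2CacheParams(
        shape=FakeShape(conv=(128, 3), temporal=(8, 16, 32)),
        dtype=FakeDtype(conv=torch.bfloat16, temporal=torch.float32),
        layers=[0, 2, 4, 6],
    )
    # The pool constructors are keyword-only now, e.g.:
    # pool = MambaPool(size=64, cache_params=params, device="cuda")
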
@@ -415,6 +432,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -446,8 +464,57 @@ class MHATokenToKVPool(KVCache):
 
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)
 
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
    def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
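
Note: the warmup above derives the tile size from the per-token row stride in bytes. A small standalone restatement of that heuristic with two worked examples (the 2048-byte stride corresponds to a hypothetical 8-head, 128-dim fp16 layer):

    import math

    def pick_kv_copy_tile(stride_bytes: int):
        # Restates the thresholds used by _init_kv_copy_and_warmup.
        if stride_bytes >= 8192:
            tile = 512
        elif stride_bytes >= 4096:
            tile = 256
        else:
            tile = 128
        num_warps = 4 if tile <= 256 else 8
        return tile, math.ceil(stride_bytes / tile), num_warps

    assert pick_kv_copy_tile(2048) == (128, 16, 4)   # 8 heads * 128 dim * 2 bytes
    assert pick_kv_copy_tile(9216) == (512, 18, 8)
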
@@ -642,13 +709,28 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )
 
 
@@ -749,6 +831,7 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
+        dtype: torch.dtype,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
@@ -757,6 +840,7 @@
     ):
         self.size = size
         self.size_swa = size_swa
+        self.dtype = dtype
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
         kwargs["page_size"] = 1
@@ -766,11 +850,13 @@
 
         self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
+            dtype=dtype,
             layer_num=self.swa_layer_nums,
             **kwargs,
         )
         self.full_kv_pool = token_to_kv_pool_class(
             size=size,
+            dtype=dtype,
             layer_num=self.full_layer_nums,
             **kwargs,
         )
@@ -1091,7 +1177,9 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self._finalize_allocation_log(size)
+        if not use_nsa:
+            # NSA will allocate indexer KV cache later and then log the total size
+            self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -1212,6 +1300,9 @@
 
 
 class NSATokenToKVPool(MLATokenToKVPool):
+    quant_block_size = 128
+    index_k_with_scale_buffer_dtype = torch.uint8
+
     def __init__(
         self,
         size: int,
@@ -1245,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool):
         # num head == 1 and head dim == 128 for index_k in NSA
         assert index_head_dim == 128
 
-        self.quant_block_size = 128
-
         assert self.page_size == 64
         self.index_k_with_scale_buffer = [
             torch.zeros(
@@ -1261,11 +1350,12 @@
                     self.page_size
                     * (index_head_dim + index_head_dim // self.quant_block_size * 4),
                 ),
-                dtype=torch.uint8,
+                dtype=self.index_k_with_scale_buffer_dtype,
                 device=device,
             )
             for _ in range(layer_num)
         ]
+        self._finalize_allocation_log(size)
 
     def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
         if self.layer_transfer_counter is not None:
@@ -1307,6 +1397,12 @@
             pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
         )
 
+    def get_kv_size_bytes(self):
+        kv_size_bytes = super().get_kv_size_bytes()
+        for index_k_cache in self.index_k_with_scale_buffer:
+            kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+        return kv_size_bytes
+
 
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
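
Note: with the class attributes above (page_size 64, index_head_dim 128, quant_block_size 128), each page of the per-layer index_k_with_scale_buffer holds 64 * (128 + 128 // 128 * 4) = 8448 uint8 bytes; this is the extra memory the new get_kv_size_bytes override now counts on top of the MLA kv_buffer. A quick arithmetic check with a hypothetical 100k-token pool:

    page_size = 64
    index_head_dim = 128
    quant_block_size = 128

    bytes_per_page = page_size * (
        index_head_dim + index_head_dim // quant_block_size * 4
    )
    assert bytes_per_page == 8448

    num_pages = 100_000 // page_size + 1            # hypothetical 100k-token pool
    per_layer_bytes = num_pages * bytes_per_page    # ~12.6 MiB per layer
    print(per_layer_bytes / 2**20)
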
@@ -1584,38 +1680,36 @@ class DoubleSparseTokenToKVPool(KVCache):
 
 
 @triton.jit
-def copy_all_layer_kv_cache(
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
 
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
 
-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
 
-    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
-    # because this copy is an inplace operation.
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
 
-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)
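
Note: functionally, copy_all_layer_kv_cache_tiled still copies whole per-token rows from src_loc to tgt_loc in every layer buffer; the second grid dimension only splits each row into byte tiles. A rough PyTorch reference of that row copy (not the sglang API), useful as a correctness oracle in tests:

    import torch

    def move_kv_cache_reference(layer_buffers, tgt_loc, src_loc):
        # Each buffer is indexed by token location along dim 0; the kernel
        # copies the full row (all trailing dims) from src to tgt.
        for buf in layer_buffers:
            rows = buf.view(buf.shape[0], -1)
            rows[tgt_loc] = rows[src_loc]

    k = torch.arange(6 * 4, dtype=torch.float32).reshape(6, 4)
    move_kv_cache_reference([k], torch.tensor([0, 1]), torch.tensor([4, 5]))
    assert torch.equal(k[0], torch.tensor([16.0, 17.0, 18.0, 19.0]))
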

sglang/srt/mem_cache/radix_cache.py
@@ -326,6 +326,8 @@ class RadixCache(BasePrefixCache):
 
         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :all_token_len
@@ -349,7 +351,8 @@
 
         old_prefix_len = len(req.prefix_indices)
         if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
-            # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
             old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
@@ -370,7 +373,8 @@
 
         token_ids = req.fill_ids
         all_token_len = len(token_ids)
-        # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :all_token_len
@@ -393,7 +397,8 @@
 
         old_prefix_len = len(req.prefix_indices)
         if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
-            # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
             old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
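
Note: the exact implementation of _convert_to_bigram_key (imported by the SWA radix cache below) is not shown in this diff; a sketch matching the behavior described in the comments above (consecutive token pairs, one key shorter than the input):

    def convert_to_bigram_key(token_ids):
        # [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)]; the key is one shorter
        # than the input, hence actual_kv_len = all_token_len - 1 for EAGLE.
        return list(zip(token_ids, token_ids[1:]))

    assert convert_to_bigram_key([1, 2, 3, 4]) == [(1, 2), (2, 3), (3, 4)]
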

sglang/srt/mem_cache/swa_radix_cache.py
@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import (
     RadixKey,
+    _convert_to_bigram_key,
     _key_match_page_size1,
     _key_match_paged,
     get_child_key,
@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
         sliding_window_size: int,
         page_size: int,
         disable: bool = False,
+        is_eagle: bool = False,
     ):
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.is_eagle = is_eagle
 
         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@
             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
             self.get_child_key_fn = partial(get_child_key, page_size=page_size)
 
+        if is_eagle:
+            self.key_convert_fn = _convert_to_bigram_key
+        else:
+            self.key_convert_fn = lambda key: key
+
         self.sliding_window_size = sliding_window_size
         self.reset()
 
@@ -376,6 +384,8 @@
             The last node create a new child if the prefix is shorter
             than the last node's value.
         """
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if self.disable or len(key) == 0:
             return MatchResult(
                 device_indices=torch.empty(
@@ -406,8 +416,15 @@
         if self.disable:
             return 0
 
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if value is None:
             value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
+
+        if self.is_eagle:
+            # Make sure the value len equal to the EAGLE bigram key len
+            value = value[: len(key)]
+
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)
 
     def cache_finished_req(self, req: Req) -> None:
@@ -422,25 +439,41 @@
             return
 
         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :all_token_len
         ]
 
         if self.page_size != 1:
-            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
-            page_aligned_len = len(kv_indices)
+            page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.clone()
+            if self.is_eagle:
+                self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
+
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
+            old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            RadixKey(token_ids[:page_aligned_len], req.extra_key),
+            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
             page_aligned_kv_indices,
-            len(req.prefix_indices),
+            old_prefix_len,
         )
 
         # Remove req slot release the cache lock
@@ -459,39 +492,56 @@
             return
 
         token_ids = req.fill_ids
+        all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :all_token_len
        ]
 
         if self.page_size != 1:
-            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
         else:
-            page_aligned_len = len(kv_indices)
+            page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.clone()
-        page_aligned_token_ids = token_ids[:page_aligned_len]
+
+        # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+        page_aligned_token_ids = token_ids[:page_aligned_token_len]
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
+            old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
             RadixKey(page_aligned_token_ids, req.extra_key),
             page_aligned_kv_indices,
-            len(req.prefix_indices),
+            old_prefix_len,
         )
 
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node, _, _ = self.match_prefix(
             RadixKey(page_aligned_token_ids, req.extra_key)
         )
-        assert len(req.prefix_indices) <= len(
+        assert old_prefix_len <= len(
             new_indices
         ), f"{req.prefix_indices=}, {new_indices=}"
         assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
         self.req_to_token_pool.write(
-            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
-            new_indices[len(req.prefix_indices) :],
+            (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
+            new_indices[old_prefix_len:],
         )
 
+        req.last_matched_prefix_len = len(new_indices)
+
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
         swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
 
@@ -501,7 +551,13 @@
                 [new_indices, kv_indices[len(new_indices) :]]
             )
         else:
-            req.prefix_indices = new_indices
+            if self.is_eagle:
+                # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
+                req.prefix_indices = torch.cat(
+                    [new_indices, kv_indices[actual_kv_len:]]
+                )
+            else:
+                req.prefix_indices = new_indices
         req.last_node = new_last_node
         req.swa_uuid_for_lock = swa_uuid_for_lock
 
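
Note: a worked example of the EAGLE length bookkeeping used in both cache_finished_req and cache_unfinished_req above, with hypothetical numbers (page_size 4, an 11-token sequence):

    page_size = 4
    is_eagle = True

    all_token_len = 11                       # len(origin_input_ids + output_ids) - 1
    actual_kv_len = all_token_len - 1        # bigram key length: 10
    page_aligned_len = actual_kv_len // page_size * page_size          # 8 kv slots kept
    page_aligned_token_len = page_aligned_len + 1 if is_eagle else page_aligned_len

    # 9 unigram tokens are inserted as 8 bigram keys paired with 8 kv indices.
    assert (page_aligned_len, page_aligned_token_len) == (8, 9)
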
sglang/srt/model_executor/cuda_graph_runner.py
@@ -849,7 +849,7 @@ class CudaGraphRunner:
             )
 
         elif self.model_runner.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+            from sglang.srt.speculative.ngram_info import NgramVerifyInput
 
             spec_info = NgramVerifyInput(
                 draft_token=None,