sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -15,6 +15,9 @@ limitations under the License.
 
 from __future__ import annotations
 
+from dataclasses import dataclass
+
+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.layers.attention.nsa import index_buf_accessor
 from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -109,17 +112,38 @@ class ReqToTokenPool:
 
 
 class MambaPool:
+    @dataclass(frozen=True, kw_only=True)
+    class State:
+        conv: torch.Tensor
+        temporal: torch.Tensor
+
+        def at_layer_idx(self, layer: int):
+            return type(self)(**{k: v[layer] for k, v in vars(self).items()})
+
+        def mem_usage_bytes(self):
+            return sum(get_tensor_size_bytes(t) for t in vars(self).values())
+
+    @dataclass(frozen=True, kw_only=True)
+    class SpeculativeState(State):
+        intermediate_ssm: torch.Tensor
+        intermediate_conv_window: torch.Tensor
+
     def __init__(
         self,
+        *,
         size: int,
-        conv_dtype: torch.dtype,
-        ssm_dtype: torch.dtype,
-        num_mamba_layers: int,
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
+        cache_params: "Mamba2CacheParams",
         device: str,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
+        conv_state_shape = cache_params.shape.conv
+        temporal_state_shape = cache_params.shape.temporal
+        conv_dtype = cache_params.dtype.conv
+        ssm_dtype = cache_params.dtype.temporal
+        num_mamba_layers = len(cache_params.layers)
+
+        # assume conv_state = (dim, state_len)
+        assert conv_state_shape[0] > conv_state_shape[1]
         conv_state = torch.zeros(
             size=(num_mamba_layers, size + 1) + conv_state_shape,
             dtype=conv_dtype,
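The `State`/`SpeculativeState` containers above replace the old positional tuple: every field stacks all Mamba layers along dim 0, `at_layer_idx` rebuilds the dataclass from one layer's views, and `mem_usage_bytes` sums the backing storage. A standalone sketch of the same pattern, with made-up tensor shapes and a local stand-in for `get_tensor_size_bytes`:

```python
from dataclasses import dataclass

import torch


def get_tensor_size_bytes(t: torch.Tensor) -> int:
    # Same accounting the pool uses: element count times element size.
    return t.numel() * t.element_size()


@dataclass(frozen=True, kw_only=True)
class State:
    conv: torch.Tensor      # (num_layers, pool_size + 1, *conv_state_shape)
    temporal: torch.Tensor  # (num_layers, pool_size + 1, *temporal_state_shape)

    def at_layer_idx(self, layer: int):
        # Rebuild the same dataclass from per-layer views (no copy).
        return type(self)(**{k: v[layer] for k, v in vars(self).items()})

    def mem_usage_bytes(self) -> int:
        return sum(get_tensor_size_bytes(t) for t in vars(self).values())


# Illustrative shapes only, not real model dimensions.
state = State(
    conv=torch.zeros(4, 9, 1536, 3),
    temporal=torch.zeros(4, 9, 128, 64),
)
layer0 = state.at_layer_idx(0)
print(layer0.conv.shape)        # torch.Size([9, 1536, 3])
print(state.mem_usage_bytes())  # total bytes across both buffers
```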
@@ -158,11 +182,11 @@ class MambaPool:
                 dtype=conv_dtype,
                 device="cuda",
             )
-            self.mamba_cache = (
-                conv_state,
-                temporal_state,
-                intermediate_ssm_state_cache,
-                intermediate_conv_window_cache,
+            self.mamba_cache = self.SpeculativeState(
+                conv=conv_state,
+                temporal=temporal_state,
+                intermediate_ssm=intermediate_ssm_state_cache,
+                intermediate_conv_window=intermediate_conv_window_cache,
             )
             logger.info(
                 f"Mamba Cache is allocated. "
@@ -172,7 +196,7 @@ class MambaPool:
                 f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
             )
         else:
-            self.mamba_cache = (conv_state, temporal_state)
+            self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
         logger.info(
             f"Mamba Cache is allocated. "
             f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
@@ -180,16 +204,14 @@ class MambaPool:
         )
         self.size = size
         self.free_slots = list(range(size))
-        self.mem_usage = self.
-
-    def get_mamba_params_all_layers(self):
-        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
+        self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
 
-    def
-
+    def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
+        assert isinstance(self.mamba_cache, self.SpeculativeState)
+        return self.mamba_cache
 
-    def
-        return
+    def mamba2_layer_cache(self, layer_id: int):
+        return self.mamba_cache.at_layer_idx(layer_id)
 
     def available_size(self):
         return len(self.free_slots)
@@ -208,7 +230,9 @@ class MambaPool:
             self.free_slots.append(free_index)
         else:
             self.free_slots.extend(free_index)
-        self.mamba_cache[
+        self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
+            :, free_index
+        ] = 0
 
     def clear(self):
         self.free_slots = list(range(self.size))
@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
 
     def __init__(
         self,
+        *,
         size: int,
         max_context_len: int,
         device: str,
         enable_memory_saver: bool,
-
-
-        mamba_layers: List[int],
-        conv_state_shape: Tuple[int, int],
-        temporal_state_shape: Tuple[int, int],
-        speculative_num_draft_tokens: int,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int = None,
     ):
         super().__init__(
             size=size,
@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
         )
 
         self.mamba_pool = MambaPool(
-            size,
-
-
-
-            conv_state_shape,
-            temporal_state_shape,
-            device,
-            speculative_num_draft_tokens,
+            size=size,
+            cache_params=cache_params,
+            device=device,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
-        self.mamba_map = {layer_id: i for i, layer_id in enumerate(
+        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
 
         self.device = device
         self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
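Both pool constructors now accept a single `cache_params` object from the new `sglang/srt/configs/mamba_utils.py` instead of separate dtype/shape/layer arguments. Only the attributes read in the hunks above (`shape.conv`, `shape.temporal`, `dtype.conv`, `dtype.temporal`, `layers`) are known from this diff; the mock below is an illustrative stand-in, not the real `Mamba2CacheParams`:

```python
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass(frozen=True)
class _Shapes:
    conv: Tuple[int, int]
    temporal: Tuple[int, int]


@dataclass(frozen=True)
class _Dtypes:
    conv: torch.dtype
    temporal: torch.dtype


@dataclass(frozen=True)
class MockMamba2CacheParams:
    # Mirrors only the attributes MambaPool/HybridReqToTokenPool read in this diff.
    shape: _Shapes
    dtype: _Dtypes
    layers: List[int]  # model layer ids that carry Mamba state


params = MockMamba2CacheParams(
    shape=_Shapes(conv=(1536, 3), temporal=(128, 64)),  # illustrative sizes
    dtype=_Dtypes(conv=torch.bfloat16, temporal=torch.float32),
    layers=[0, 2, 4, 6],
)

# The pool derives everything it used to take as separate arguments:
num_mamba_layers = len(params.layers)
conv_state_shape = params.shape.conv
assert conv_state_shape[0] > conv_state_shape[1]  # same sanity check as the pool
```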
@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
     def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
         return self.req_index_to_mamba_index_mapping[req_indices]
 
-    def
+    def mamba2_layer_cache(self, layer_id: int):
         assert layer_id in self.mamba_map
-        return self.mamba_pool.
+        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
 
-    def
-        return self.mamba_pool.
+    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
+        return self.mamba_pool.get_speculative_mamba2_params_all_layers()
 
     # For chunk prefill, we can not free mamba cache, we need use it in the future
     def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
@@ -415,6 +432,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -446,8 +464,57 @@ class MHATokenToKVPool(KVCache):
 
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)
 
+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
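The warmup above derives a byte-tile size from the per-layer stride and a 2D launch grid from that tile size. A small self-contained sketch of the same heuristic (thresholds copied from the hunk; the example stride is made up):

```python
def kv_copy_tiling(stride_bytes: int) -> dict:
    # Same thresholds as _init_kv_copy_and_warmup above.
    if stride_bytes >= 8192:
        bytes_per_tile = 512
    elif stride_bytes >= 4096:
        bytes_per_tile = 256
    else:
        bytes_per_tile = 128
    return {
        "bytes_per_tile": bytes_per_tile,
        # Ceiling division: tiles needed to cover one token's row of bytes.
        "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
        "num_warps": 4 if bytes_per_tile <= 256 else 8,
    }


# Example: 2048 bytes per token per layer buffer (made-up number).
cfg = kv_copy_tiling(2048)
assert cfg == {"bytes_per_tile": 128, "byte_tiles": 16, "num_warps": 4}

# The kernel then launches on a 2D grid:
#   axis 0 -> one program per layer K/V buffer (len(data_ptrs)),
#   axis 1 -> one program per byte tile of a token row (byte_tiles).
```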
@@ -642,13 +709,28 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-
-
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )
 
 
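`num_locs_upper` feeds `tl.arange`, which needs a compile-time power-of-two extent, so `move_kv_cache` rounds the location count up with `next_power_of_2` and the kernel masks the padding. A plain-Python equivalent of that rounding (Triton ships `triton.next_power_of_2` for this):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n (for n >= 1), e.g. 5 -> 8, 8 -> 8.
    return 1 << (n - 1).bit_length()


assert [next_power_of_2(n) for n in (1, 3, 5, 8, 1000)] == [1, 4, 8, 8, 1024]
# Inside the kernel, loc_idx = tl.arange(0, num_locs_upper) is masked with
# loc_idx < num_locs, so the padded slots never read or write anything.
```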
@@ -749,6 +831,7 @@ class SWAKVPool(KVCache):
         self,
         size: int,
         size_swa: int,
+        dtype: torch.dtype,
         swa_attention_layer_ids: List[int],
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
@@ -757,6 +840,7 @@ class SWAKVPool(KVCache):
     ):
         self.size = size
         self.size_swa = size_swa
+        self.dtype = dtype
         self.swa_layer_nums = len(swa_attention_layer_ids)
         self.full_layer_nums = len(full_attention_layer_ids)
         kwargs["page_size"] = 1
@@ -766,11 +850,13 @@ class SWAKVPool(KVCache):
 
         self.swa_kv_pool = token_to_kv_pool_class(
             size=size_swa,
+            dtype=dtype,
             layer_num=self.swa_layer_nums,
             **kwargs,
         )
         self.full_kv_pool = token_to_kv_pool_class(
             size=size,
+            dtype=dtype,
             layer_num=self.full_layer_nums,
             **kwargs,
         )
@@ -1091,7 +1177,9 @@ class MLATokenToKVPool(KVCache):
             dtype=torch.uint64,
             device=self.device,
         )
-        self._finalize_allocation_log(size)
+        if not use_nsa:
+            # NSA will allocate indexer KV cache later and then log the total size
+            self._finalize_allocation_log(size)
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -1212,6 +1300,9 @@
 
 
 class NSATokenToKVPool(MLATokenToKVPool):
+    quant_block_size = 128
+    index_k_with_scale_buffer_dtype = torch.uint8
+
     def __init__(
         self,
         size: int,
@@ -1245,8 +1336,6 @@ class NSATokenToKVPool(MLATokenToKVPool):
         # num head == 1 and head dim == 128 for index_k in NSA
         assert index_head_dim == 128
 
-        self.quant_block_size = 128
-
         assert self.page_size == 64
         self.index_k_with_scale_buffer = [
             torch.zeros(
@@ -1261,11 +1350,12 @@
                     self.page_size
                     * (index_head_dim + index_head_dim // self.quant_block_size * 4),
                 ),
-                dtype=torch.uint8,
+                dtype=self.index_k_with_scale_buffer_dtype,
                 device=device,
             )
             for _ in range(layer_num)
         ]
+        self._finalize_allocation_log(size)
 
     def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
         if self.layer_transfer_counter is not None:
@@ -1307,6 +1397,12 @@
             pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
         )
 
+    def get_kv_size_bytes(self):
+        kv_size_bytes = super().get_kv_size_bytes()
+        for index_k_cache in self.index_k_with_scale_buffer:
+            kv_size_bytes += get_tensor_size_bytes(index_k_cache)
+        return kv_size_bytes
+
 
 class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
     def __init__(
@@ -1584,38 +1680,36 @@ class DoubleSparseTokenToKVPool(KVCache):
 
 
 @triton.jit
-def
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)
+
     stride = tl.load(strides + bid)
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))
 
-
-
+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
 
-
-
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs
 
-
-
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)
 
-
-
-
-
-
-
-    )
-    tl.store(
-        data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-        value,
-        mask=mask,
-    )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)
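Behaviorally, the rewritten kernel is still a batched row copy inside each layer's K/V buffer: for every pair `i`, the `stride` bytes at row `src_loc[i]` move to row `tgt_loc[i]`; the second grid axis only splits each row into `BYTES_PER_TILE` chunks for parallelism. A toy PyTorch reference of those semantics (tiling omitted, buffer contents invented):

```python
import torch


def move_kv_cache_reference(buffers, tgt_loc, src_loc):
    # buffers: one (num_tokens, ...) tensor per layer K/V buffer.
    # tgt_loc/src_loc: 1D index tensors of equal length.
    for buf in buffers:
        # RHS is materialized before the scatter, so overlap is safe,
        # mirroring the load-then-store order in the kernel.
        buf[tgt_loc] = buf[src_loc]


k = torch.arange(12, dtype=torch.float32).reshape(6, 2)  # toy "KV buffer"
move_kv_cache_reference(
    [k], tgt_loc=torch.tensor([0, 1]), src_loc=torch.tensor([4, 5])
)
print(k[:2])  # rows 4 and 5 copied into rows 0 and 1
```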
sglang/srt/mem_cache/radix_cache.py

@@ -326,6 +326,8 @@ class RadixCache(BasePrefixCache):
 
         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, :all_token_len
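The bigram-key comment added above is easier to read with a concrete conversion. The real helper is `_convert_to_bigram_key` (re-exported to `swa_radix_cache.py` later in this diff); the sketch below is an illustrative reimplementation, not the library code:

```python
def to_bigram_key(token_ids):
    # [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)]; length shrinks by one.
    return list(zip(token_ids, token_ids[1:]))


token_ids = [1, 2, 3, 4]
bigrams = to_bigram_key(token_ids)
assert bigrams == [(1, 2), (2, 3), (3, 4)]
assert len(bigrams) == len(token_ids) - 1  # hence actual_kv_len = all_token_len - 1
```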
@@ -349,7 +351,8 @@ class RadixCache(BasePrefixCache):
 
         old_prefix_len = len(req.prefix_indices)
         if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
-            #
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
             old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
|
|
370
373
|
|
371
374
|
token_ids = req.fill_ids
|
372
375
|
all_token_len = len(token_ids)
|
373
|
-
#
|
376
|
+
# For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
|
377
|
+
# So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
|
374
378
|
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
375
379
|
kv_indices = self.req_to_token_pool.req_to_token[
|
376
380
|
req.req_pool_idx, :all_token_len
|
@@ -393,7 +397,8 @@ class RadixCache(BasePrefixCache):
|
|
393
397
|
|
394
398
|
old_prefix_len = len(req.prefix_indices)
|
395
399
|
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
396
|
-
#
|
400
|
+
# In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
|
401
|
+
# Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
|
397
402
|
old_prefix_len -= 1
|
398
403
|
|
399
404
|
# Radix Cache takes one ref in memory pool
|
sglang/srt/mem_cache/swa_radix_cache.py

@@ -32,6 +32,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import (
     RadixKey,
+    _convert_to_bigram_key,
     _key_match_page_size1,
     _key_match_paged,
     get_child_key,
@@ -327,12 +328,14 @@ class SWARadixCache(BasePrefixCache):
         sliding_window_size: int,
         page_size: int,
         disable: bool = False,
+        is_eagle: bool = False,
     ):
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.is_eagle = is_eagle
 
         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@ class SWARadixCache(BasePrefixCache):
             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
             self.get_child_key_fn = partial(get_child_key, page_size=page_size)
 
+        if is_eagle:
+            self.key_convert_fn = _convert_to_bigram_key
+        else:
+            self.key_convert_fn = lambda key: key
+
         self.sliding_window_size = sliding_window_size
         self.reset()
 
@@ -376,6 +384,8 @@ class SWARadixCache(BasePrefixCache):
         The last node create a new child if the prefix is shorter
         than the last node's value.
         """
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if self.disable or len(key) == 0:
             return MatchResult(
                 device_indices=torch.empty(
@@ -406,8 +416,15 @@ class SWARadixCache(BasePrefixCache):
         if self.disable:
             return 0
 
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if value is None:
             value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
+
+        if self.is_eagle:
+            # Make sure the value len equal to the EAGLE bigram key len
+            value = value[: len(key)]
+
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)
 
     def cache_finished_req(self, req: Req) -> None:
@@ -422,25 +439,41 @@ class SWARadixCache(BasePrefixCache):
             return
 
         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :all_token_len
         ]
 
         if self.page_size != 1:
-            page_aligned_len =
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
-            page_aligned_len =
+            page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.clone()
+            if self.is_eagle:
+                self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
+
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
+            old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            RadixKey(token_ids[:
+            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
             page_aligned_kv_indices,
-
+            old_prefix_len,
         )
 
         # Remove req slot release the cache lock
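The EAGLE length bookkeeping in this hunk reduces to small integer adjustments; a worked example with invented numbers shows why the key slice gets `+1` back after the KV length lost `-1`:

```python
is_eagle = True
page_size = 4
all_token_len = 11  # tokens whose KV this request actually cached

# Bigram keys are one shorter than the token sequence.
actual_kv_len = all_token_len - 1 if is_eagle else all_token_len           # 10

# Page alignment is done on the KV (bigram) length...
page_aligned_len = actual_kv_len // page_size * page_size                  # 8

# ...but the radix key is sliced from the original token ids, so add 1 back.
page_aligned_token_len = (
    page_aligned_len + 1 if is_eagle else page_aligned_len                 # 9
)

assert (actual_kv_len, page_aligned_len, page_aligned_token_len) == (10, 8, 9)
```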
@@ -459,39 +492,56 @@ class SWARadixCache(BasePrefixCache):
             return
 
         token_ids = req.fill_ids
+        all_token_len = len(token_ids)
+        # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
+        # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :all_token_len
         ]
 
         if self.page_size != 1:
-            page_aligned_len =
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
         else:
-            page_aligned_len =
+            page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.clone()
-
+
+        # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+        page_aligned_token_ids = token_ids[:page_aligned_token_len]
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In EAGLE chunked prefill case, the prefix_indices included one unmatched token (kv_indices[actual_kv_len:])
+            # Here we -1 to make sure the kv of the unmatched token can be freed correctly to avoid memory leak
+            old_prefix_len -= 1
 
         # Radix Cache takes one ref in memory pool
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
             RadixKey(page_aligned_token_ids, req.extra_key),
             page_aligned_kv_indices,
-
+            old_prefix_len,
         )
 
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node, _, _ = self.match_prefix(
             RadixKey(page_aligned_token_ids, req.extra_key)
         )
-        assert
+        assert old_prefix_len <= len(
             new_indices
         ), f"{req.prefix_indices=}, {new_indices=}"
         assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
         self.req_to_token_pool.write(
-            (req.req_pool_idx, slice(
-            new_indices[
+            (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
+            new_indices[old_prefix_len:],
         )
 
+        req.last_matched_prefix_len = len(new_indices)
+
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
         swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
 
@@ -501,7 +551,13 @@ class SWARadixCache(BasePrefixCache):
                 [new_indices, kv_indices[len(new_indices) :]]
             )
         else:
-            req.prefix_indices = new_indices
+            if self.is_eagle:
+                # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
+                req.prefix_indices = torch.cat(
+                    [new_indices, kv_indices[actual_kv_len:]]
+                )
+            else:
+                req.prefix_indices = new_indices
         req.last_node = new_last_node
         req.swa_uuid_for_lock = swa_uuid_for_lock
 
sglang/srt/model_executor/cuda_graph_runner.py

@@ -849,7 +849,7 @@ class CudaGraphRunner:
             )
 
         elif self.model_runner.spec_algorithm.is_ngram():
-            from sglang.srt.speculative.ngram_utils import NgramVerifyInput
+            from sglang.srt.speculative.ngram_info import NgramVerifyInput
 
             spec_info = NgramVerifyInput(
                 draft_token=None,
|