sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
30
30
|
from sglang.srt.managers.io_struct import (
|
31
31
|
GetWeightsByNameReqInput,
|
32
32
|
InitWeightsUpdateGroupReqInput,
|
33
|
+
LoadLoRAAdapterReqInput,
|
34
|
+
UnloadLoRAAdapterReqInput,
|
33
35
|
UpdateWeightFromDiskReqInput,
|
34
36
|
UpdateWeightsFromDistributedReqInput,
|
35
37
|
UpdateWeightsFromTensorReqInput,
|
@@ -257,7 +259,7 @@ class TpModelWorker:
|
|
257
259
|
self, recv_req: UpdateWeightsFromDistributedReqInput
|
258
260
|
):
|
259
261
|
success, message = self.model_runner.update_weights_from_distributed(
|
260
|
-
recv_req.
|
262
|
+
recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
|
261
263
|
)
|
262
264
|
return success, message
|
263
265
|
|
@@ -275,3 +277,13 @@ class TpModelWorker:
|
|
275
277
|
recv_req.name, recv_req.truncate_size
|
276
278
|
)
|
277
279
|
return parameter
|
280
|
+
|
281
|
+
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Forward a LoRA load request to the model runner.

    Passes ``recv_req.lora_name`` and ``recv_req.lora_path`` to
    ``self.model_runner.load_lora_adapter`` and returns its result
    unchanged.
    """
    load = self.model_runner.load_lora_adapter
    return load(recv_req.lora_name, recv_req.lora_path)
|
286
|
+
|
287
|
+
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Forward a LoRA unload request to the model runner.

    Passes ``recv_req.lora_name`` to
    ``self.model_runner.unload_lora_adapter`` and returns its result
    unchanged.
    """
    unload = self.model_runner.unload_lora_adapter
    return unload(recv_req.lora_name)
|
@@ -26,6 +26,8 @@ import torch
|
|
26
26
|
from sglang.srt.managers.io_struct import (
|
27
27
|
GetWeightsByNameReqInput,
|
28
28
|
InitWeightsUpdateGroupReqInput,
|
29
|
+
LoadLoRAAdapterReqInput,
|
30
|
+
UnloadLoRAAdapterReqInput,
|
29
31
|
UpdateWeightFromDiskReqInput,
|
30
32
|
UpdateWeightsFromDistributedReqInput,
|
31
33
|
UpdateWeightsFromTensorReqInput,
|
@@ -268,6 +270,12 @@ class TpModelWorkerClient:
|
|
268
270
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    """Delegate to ``self.worker.get_weights_by_name`` and return its result."""
    return self.worker.get_weights_by_name(recv_req)
|
270
272
|
|
273
|
+
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Delegate to ``self.worker.load_lora_adapter`` and return its result."""
    return self.worker.load_lora_adapter(recv_req)
|
275
|
+
|
276
|
+
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Delegate to ``self.worker.unload_lora_adapter`` and return its result."""
    return self.worker.unload_lora_adapter(recv_req)
|
278
|
+
|
271
279
|
def __delete__(self):
|
272
280
|
self.input_queue.put((None, None))
|
273
281
|
self.copy_queue.put((None, None, None))
|
@@ -20,12 +20,14 @@ Page-aligned memory pool.
|
|
20
20
|
"""
|
21
21
|
|
22
22
|
import abc
|
23
|
+
import weakref
|
23
24
|
from typing import TYPE_CHECKING
|
24
25
|
|
25
26
|
import torch
|
26
27
|
import triton
|
27
28
|
import triton.language as tl
|
28
29
|
|
30
|
+
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
29
31
|
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
30
32
|
|
31
33
|
if TYPE_CHECKING:
|
@@ -55,6 +57,11 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|
55
57
|
def debug_print(self) -> str:
    """Return a debug string for this allocator; this implementation returns ""."""
    return ""
|
57
59
|
|
60
|
+
def log_usage(self, evictable_size: int = 0):
    """Build a usage log fragment for this allocator.

    A token counts as used when it is neither available in the allocator
    nor covered by ``evictable_size``.

    Returns:
        A ``(msg, num_used)`` tuple: the formatted fragment (absolute count
        and fraction of ``self.size``) and the used-token count.
    """
    reclaimable = self.available_size() + evictable_size
    num_used = self.size - reclaimable
    usage_ratio = num_used / self.size
    return f"#token: {num_used}, token usage: {usage_ratio:.2f}, ", num_used
|
64
|
+
|
58
65
|
def available_size(self):
    """Return the number of free token slots: free pages times ``page_size``."""
    return len(self.free_pages) * self.page_size
|
60
67
|
|
@@ -146,6 +153,128 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
146
153
|
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
|
147
154
|
|
148
155
|
|
156
|
+
class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Allocator for SWA hybrid KV cache.

    Wraps two ``TokenToKVPoolAllocator`` instances, one over
    ``kvcache.full_kv_pool`` and one over ``kvcache.swa_kv_pool``, and
    maintains ``full_to_swa_index_mapping``: a tensor indexed by
    full-attention slot that stores the paired SWA slot. ``free_swa``
    treats a mapping value of 0 as "no SWA slot to release".
    """

    def __init__(
        self,
        size: int,
        size_swa: int,
        dtype: torch.dtype,
        device: str,
        kvcache: SWAKVPool,
    ):
        """Create the two sub-allocators and the full->SWA index mapping.

        Args:
            size: Capacity of the full-attention pool.
            size_swa: Capacity of the SWA pool.
            dtype: Passed through to the base class and both sub-allocators.
            device: Device for the sub-allocators and the mapping tensor.
            kvcache: Must be an ``SWAKVPool``; its ``full_kv_pool`` and
                ``swa_kv_pool`` back the two sub-allocators.
        """
        super().__init__(size, 1, dtype, device, kvcache)
        assert isinstance(kvcache, SWAKVPool)
        self._size_full = size
        self._size_swa = size_swa
        self.full_attn_allocator = TokenToKVPoolAllocator(
            size,
            dtype,
            device,
            kvcache.full_kv_pool,
        )
        self.swa_attn_allocator = TokenToKVPoolAllocator(
            size_swa,
            dtype,
            device,
            kvcache.swa_kv_pool,
        )
        self.full_to_swa_index_mapping = torch.empty(
            size + size_swa + 1,
            dtype=torch.int64,
            device=device,
        )
        self.clear()

        # Share the mapping tensor with the KV pool object.
        self._kvcache.full_to_swa_index_mapping = self.full_to_swa_index_mapping

    def available_size(self):
        """Return the smaller of the two pools' free sizes.

        ``alloc`` takes the same number of slots from both pools, so the
        scarcer pool is the limit.
        """
        return min(self.full_available_size(), self.swa_available_size())

    def full_available_size(self):
        """Return the free size of the full-attention sub-allocator."""
        return self.full_attn_allocator.available_size()

    def swa_available_size(self):
        """Return the free size of the SWA sub-allocator."""
        return self.swa_attn_allocator.available_size()

    @property
    def size_full(self):
        """Capacity of the full-attention pool given at construction."""
        return self._size_full

    @property
    def size_swa(self):
        """Capacity of the SWA pool given at construction."""
        return self._size_swa

    def debug_print(self) -> str:
        """Return a string with the free size of both sub-allocators."""
        msg = ""
        msg += f"#swa-available-size: {self.swa_attn_allocator.available_size()}, "
        msg += (
            f"#full-attn-available-size: {self.full_attn_allocator.available_size()}, "
        )
        return msg

    def log_usage(self, swa_evictable_size: int = 0, full_evictable_size: int = 0):
        """Build a usage log fragment covering both pools.

        Returns:
            A ``(msg, used_full)`` tuple. Only the full-attention used count
            is returned as the second element.
        """
        used_full = self.size_full - (self.full_available_size() + full_evictable_size)
        used_swa = self.size_swa - (self.swa_available_size() + swa_evictable_size)
        msg = (
            f"#token: full={used_full}, swa={used_swa}, "
            f"token usage: full={used_full / self.size_full:.2f}, "
            f"swa={used_swa / self.size_swa:.2f}, "
        )
        return msg, used_full

    def get_kvcache(self):
        """Return the KV pool object passed at construction."""
        return self._kvcache

    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
        """Map full-attention slot indices to SWA slot indices (int32)."""
        assert self.full_to_swa_index_mapping is not None
        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)

    def alloc(self, need_size: int):
        """Allocate ``need_size`` slots in both pools and record the pairing.

        Returns:
            The full-attention slot indices, or ``None`` when either pool has
            fewer than ``need_size`` free slots (nothing is allocated then).
        """
        if need_size > self.full_attn_allocator.available_size():
            return None
        if need_size > self.swa_attn_allocator.available_size():
            return None

        alloc_full_indices = self.full_attn_allocator.alloc(need_size)
        alloc_swa_indices = self.swa_attn_allocator.alloc(need_size)
        self.full_to_swa_index_mapping[alloc_full_indices] = alloc_swa_indices
        return alloc_full_indices

    def free(self, free_index: torch.Tensor):
        """Release full-attention slots and their paired SWA slots.

        When ``is_not_in_free_group`` is false the indices are only queued
        in ``free_group`` instead of being released here.
        """
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            self.full_attn_allocator.free(free_index)
            self.free_swa(free_index)
        else:
            self.free_group.append(free_index)
        assert (
            self.full_attn_allocator.available_size() <= self.full_attn_allocator.size
        )
        assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size

    def free_swa(self, free_index: torch.Tensor):
        """Release only the SWA slots paired with ``free_index``.

        Entries whose mapping is 0 are skipped, and the mapping for every
        given index is reset to 0 afterwards, so calling this again for the
        same indices releases nothing further.
        """
        swa_indices = self.full_to_swa_index_mapping[free_index]
        swa_indices = swa_indices[swa_indices > 0]
        self.swa_attn_allocator.free(swa_indices)
        self.full_to_swa_index_mapping[free_index] = 0

    def backup_state(self):
        """Not supported for this allocator."""
        raise NotImplementedError

    def restore_state(self, state):
        """Not supported for this allocator."""
        raise NotImplementedError

    def clear(self):
        """Reset both sub-allocators, the index mapping, and free-group state."""
        self.swa_attn_allocator.clear()
        self.full_attn_allocator.clear()
        self.full_to_swa_index_mapping.fill_(0)
        # free() branches on ``is_not_in_free_group``. The previous code wrote
        # a differently named attribute (``is_in_free_group``), so the flag
        # free() actually reads was never reset here.
        self.is_not_in_free_group = True
        self.free_group = []
|
276
|
+
|
277
|
+
|
149
278
|
@triton.jit
|
150
279
|
def alloc_extend_kernel(
|
151
280
|
pre_lens_ptr,
|
@@ -411,3 +540,164 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
411
540
|
)
|
412
541
|
self.is_not_in_free_group = True
|
413
542
|
self.free_group = []
|
543
|
+
|
544
|
+
|
545
|
+
def alloc_extend_kernel_ascend(
    prefix_lens,
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
    device,
):
    """Fill ``out_indices`` with KV slot indices for each request's new tokens.

    For request ``i`` the tokens in ``[prefix_lens[i], seq_lens[i])`` are
    placed in up to three parts, written contiguously into ``out_indices``:

    1. the remainder of the page the prefix ends in, continuing from
       ``last_loc[i] + 1``;
    2. whole new pages taken from ``free_pages``;
    3. a trailing partial page, also taken from ``free_pages``.

    Args:
        prefix_lens: 1-D integer tensor of already-placed lengths.
        seq_lens: 1-D integer tensor of target lengths (same length).
        last_loc: 1-D tensor holding the slot of each prefix's last token.
        free_pages: 1-D tensor of free page ids; read only, not consumed here.
        out_indices: 1-D output tensor, written in place.
        page_size: Number of slots per page.
        device: Device on which the in-page offset tensor is created.

    Returns:
        1-D tensor with the number of new pages used by each request.
    """
    extend_lens = seq_lens - prefix_lens
    end_pos = torch.cumsum(extend_lens, 0)
    start_pos = end_pos - extend_lens
    num_new_pages = (seq_lens + page_size - 1) // page_size - (
        prefix_lens + page_size - 1
    ) // page_size
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)

    # Pull the per-request scalars to the host once instead of converting
    # tensor elements one at a time inside the loop.
    prefix_list = prefix_lens.tolist()
    seq_list = seq_lens.tolist()
    out_starts = start_pos.tolist()
    out_ends = end_pos.tolist()
    page_starts = start_new_pages.tolist()
    page_ends = end_new_pages.tolist()

    for i, (pre_len, seq_len) in enumerate(zip(prefix_list, seq_list)):
        out_start = out_starts[i]
        aligned_prefix = (pre_len + page_size - 1) // page_size * page_size

        # Part 1: finish the page the prefix already occupies.
        num1 = min(seq_len, aligned_prefix) - pre_len
        if num1 > 0:
            out_indices[out_start : out_start + num1] = (
                last_loc[i] + 1 + pos_in_page[:num1]
            )
        if pre_len + num1 >= seq_len:
            # Everything fit in the existing page. Without this exit the
            # part-2 count goes negative and part 3 writes slots that belong
            # to no new page.
            continue

        # Part 2: whole new pages.
        num2 = seq_len // page_size * page_size - aligned_prefix
        if num2 > 0:
            first_page = page_starts[i]
            pages = (
                free_pages[first_page : first_page + num2 // page_size] * page_size
            )
            out_indices[out_start + num1 : out_start + num1 + num2] = (
                pages.view(-1, 1) + pos_in_page.view(1, -1)
            ).view(-1)

        # Part 3: trailing partial page (the last new page of this request).
        num3 = seq_len % page_size
        if num3 > 0:
            out_end = out_ends[i]
            out_indices[out_end - num3 : out_end] = (
                free_pages[page_ends[i] - 1] * page_size + pos_in_page[:num3]
            )
    return num_new_pages
|
598
|
+
|
599
|
+
|
600
|
+
def alloc_decode_kernel_ascend(
    seq_lens,
    last_loc,
    free_pages,
    out_indices,
    page_size,
):
    """Pick the KV slot for the single new token of each request.

    A request whose new token starts a new page gets the first slot of the
    next page from ``free_pages``; otherwise it gets ``last_loc + 1``.
    ``out_indices`` is written in place, one entry per request.

    Returns:
        1-D tensor with the number of new pages used by each request.
    """
    pages_after = (seq_lens + page_size - 1) // page_size
    pages_before = (seq_lens - 1 + page_size - 1) // page_size
    num_new_pages = pages_after - pages_before
    end_new_pages = torch.cumsum(num_new_pages, 0)
    start_new_pages = end_new_pages - num_new_pages
    for row, opens_page in enumerate(num_new_pages):
        if not opens_page:
            # Still room in the current page: continue right after last_loc.
            out_indices[row] = last_loc[row] + 1
            continue
        out_indices[row] = free_pages[start_new_pages[row]] * page_size
    return num_new_pages
|
618
|
+
|
619
|
+
|
620
|
+
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
    """Paged allocator that uses the pure-PyTorch ``*_kernel_ascend`` helpers.

    ``alloc_extend`` and ``alloc_decode`` return ``None`` when the request
    needs more pages than are free. That check runs before the helper is
    called, because the helpers index ``free_pages`` directly and would
    otherwise raise on exhaustion instead of reaching the ``None`` return.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: KVCache,
    ):
        super().__init__(size, page_size, dtype, device, kvcache)
        # Holds the per-request new-page counts returned by the last helper call.
        self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)

    def alloc_extend(
        self,
        prefix_lens: torch.Tensor,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
        extend_num_tokens: int,
    ):
        """Allocate slots for the tokens in ``[prefix_lens, seq_lens)``.

        Returns:
            An int32 tensor of ``extend_num_tokens`` slot indices, or ``None``
            when there are not enough free pages (no state is changed then).
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 1) % self.page_size == prefix_lens % self.page_size
            )

        page_size = self.page_size
        # Same per-request count the helper returns, computed up front so the
        # capacity check happens before free_pages is indexed.
        num_new_pages = int(
            (
                (seq_lens + page_size - 1) // page_size
                - (prefix_lens + page_size - 1) // page_size
            ).sum()
        )
        if num_new_pages > len(self.free_pages):
            return None

        out_indices = torch.empty(
            (extend_num_tokens,), dtype=torch.int32, device=self.device
        )

        self.ret_values = alloc_extend_kernel_ascend(
            prefix_lens,
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
            self.device,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def alloc_decode(
        self,
        seq_lens: torch.Tensor,
        last_loc: torch.Tensor,
    ):
        """Allocate one slot per request for a decode step.

        Returns:
            An int32 tensor with one slot index per request, or ``None`` when
            there are not enough free pages (no state is changed then).
        """
        if self.debug_mode:
            assert torch.all(
                (last_loc + 2) % self.page_size == seq_lens % self.page_size
            )

        page_size = self.page_size
        # Checked before the helper runs, for the same reason as alloc_extend.
        num_new_pages = int(
            (
                (seq_lens + page_size - 1) // page_size
                - (seq_lens - 1 + page_size - 1) // page_size
            ).sum()
        )
        if num_new_pages > len(self.free_pages):
            return None

        bs = len(seq_lens)
        out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

        self.ret_values = alloc_decode_kernel_ascend(
            seq_lens,
            last_loc,
            self.free_pages,
            out_indices,
            self.page_size,
        )

        if self.debug_mode:
            assert len(torch.unique(out_indices)) == len(out_indices)

        self.free_pages = self.free_pages[num_new_pages:]
        return out_indices

    def clear(self):
        """Reset via the parent class, then store ``free_pages`` as int32."""
        super().clear()
        self.free_pages = self.free_pages.to(torch.int32)
|
@@ -2,11 +2,14 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
4
4
|
|
5
|
-
from typing import TYPE_CHECKING, Any
|
5
|
+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
6
6
|
|
7
7
|
import torch
|
8
8
|
|
9
|
-
from sglang.srt.mem_cache.allocator import
|
9
|
+
from sglang.srt.mem_cache.allocator import (
|
10
|
+
BaseTokenToKVPoolAllocator,
|
11
|
+
SWATokenToKVPoolAllocator,
|
12
|
+
)
|
10
13
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
|
11
14
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
12
15
|
|
@@ -63,3 +66,32 @@ class ChunkCache(BasePrefixCache):
|
|
63
66
|
|
64
67
|
def pretty_print(self):
    """Return a printable view of the cache; this implementation returns ""."""
    return ""
|
69
|
+
|
70
|
+
|
71
|
+
class SWAChunkCache(ChunkCache):
    """ChunkCache with support for hybrid KV cache operations."""

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: SWATokenToKVPoolAllocator,
        page_size: int,
    ):
        """Initialise the parent cache; the allocator must be the SWA variant."""
        super().__init__(req_to_token_pool, token_to_kv_pool_allocator, page_size)
        assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)

    def evict(
        self,
        req: Req,
        prelen: int,
        attention_chunk_size: int,
    ):
        """Release SWA slots of ``req`` that lie before the current chunk.

        Nothing happens until ``prelen`` is at least one full chunk past
        ``req.evicted_seqlen_local``. Then the slots from the previous
        eviction point up to the last chunk boundary at or below ``prelen``
        are passed to the allocator's ``free_swa`` and the eviction point is
        advanced to that boundary.
        """
        already_evicted = req.evicted_seqlen_local
        if prelen < already_evicted + attention_chunk_size:
            return

        chunk_boundary = attention_chunk_size * (prelen // attention_chunk_size)
        stale_slots = self.req_to_token_pool.req_to_token[
            req.req_pool_idx, already_evicted:chunk_boundary
        ]
        self.token_to_kv_pool_allocator.free_swa(stale_slots)
        req.evicted_seqlen_local = chunk_boundary
|
@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
|
|
34
34
|
hicache_ratio: float,
|
35
35
|
hicache_size: int,
|
36
36
|
hicache_write_policy: str,
|
37
|
+
hicache_io_backend: str,
|
37
38
|
):
|
38
39
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
39
40
|
if isinstance(self.kv_cache, MHATokenToKVPool):
|
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
|
|
56
57
|
page_size,
|
57
58
|
load_cache_event=self.load_cache_event,
|
58
59
|
write_policy=hicache_write_policy,
|
60
|
+
io_backend=hicache_io_backend,
|
59
61
|
)
|
60
62
|
|
61
63
|
# record the nodes with ongoing write through
|