sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
Detailed per-file diff below: sglang/srt/mem_cache/memory_pool.py (+402 -66).

```diff
@@ -27,16 +27,18 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import triton
 import triton.language as tl
+from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
 
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import …
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -66,6 +68,7 @@ class ReqToTokenPool:
         self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
+
         self.free_slots = list(range(size))
 
     def write(self, indices, values):
```
```diff
@@ -121,6 +124,7 @@ class KVCache(abc.ABC):
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
+        self.mem_usage = 0
 
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
```
```diff
@@ -147,13 +151,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    …
-    …
-    …
-    …
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
         raise NotImplementedError()
 
-    …
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
         raise NotImplementedError()
 
     def register_layer_transfer_counter(self, layer_transfer_counter):
```
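The two new abstract methods define the HiCache transfer contract that every concrete pool must now satisfy: load one layer's KV entries from a host (CPU) pool, and back up all layers to it. A minimal sketch of the expected semantics, using plain tensor indexing in place of the `sgl_kernel.kvcacheio` kernels (`ToyHostPool`/`ToyKVCache` and the `io_backend` value are illustrative, not sglang APIs):

```python
import torch

class ToyHostPool:
    # Stand-in for a host-side pool: one [slots, item_size] buffer per layer.
    def __init__(self, layer_num: int, size: int, item_size: int):
        self.buf = [torch.zeros(size, item_size) for _ in range(layer_num)]

class ToyKVCache:
    def __init__(self, layer_num: int, size: int, item_size: int, device: str = "cpu"):
        self.layer_num = layer_num
        self.buf = [torch.zeros(size, item_size, device=device) for _ in range(layer_num)]

    def load_from_host_per_layer(
        self, host_pool, host_indices, device_indices, layer_id, io_backend
    ):
        # Pull the selected host slots into this layer's device buffer.
        self.buf[layer_id][device_indices] = host_pool.buf[layer_id][host_indices].to(
            self.buf[layer_id].device, non_blocking=True
        )

    def backup_to_host_all_layer(self, host_pool, host_indices, device_indices, io_backend):
        # Push the selected device slots of every layer back to the host pool.
        for layer_id in range(self.layer_num):
            host_pool.buf[layer_id][host_indices] = self.buf[layer_id][device_indices].to("cpu")

pool, host = ToyKVCache(2, 16, 8), ToyHostPool(2, 32, 8)
pool.load_from_host_per_layer(host, torch.tensor([3, 5]), torch.tensor([0, 1]), 0, "direct")
```

The per-layer load / all-layer backup asymmetry mirrors the access pattern: loads are overlapped with layer-by-layer compute, while backups happen in bulk.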
```diff
@@ -191,7 +198,6 @@ class MHATokenToKVPool(KVCache):
             start_layer,
             end_layer,
         )
-
         self.head_num = head_num
         self.head_dim = head_dim
 
```
```diff
@@ -218,6 +224,7 @@ class MHATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
+        self.mem_usage = (k_size + v_size) / GB
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
```
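Every pool now reports `mem_usage` in GB right after allocation. For the MHA pool the figure follows directly from the buffer shapes; a back-of-envelope check with made-up model dimensions:

```python
# K and V each hold size * head_num * head_dim elements per layer.
size, head_num, head_dim, layer_num = 100_000, 8, 128, 32  # illustrative values
dtype_bytes = 2                                            # fp16 / bf16
GB = 1024**3

k_size = size * head_num * head_dim * dtype_bytes * layer_num
v_size = k_size
print(f"mem_usage = {(k_size + v_size) / GB:.2f} GB")      # ~12.21 GB here
```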
```diff
@@ -244,7 +251,7 @@ class MHATokenToKVPool(KVCache):
                 )
                 for _ in range(self.layer_num)
             ]
-
+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
```
```diff
@@ -278,24 +285,24 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).data_ptr()
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).nbytes
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes * self.page_size
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i)[0].nbytes * self.page_size
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
```
```diff
@@ -338,49 +345,73 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()
 
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
+    def load_from_host_per_layer(
+        self,
+        host_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
         )
-        return flatten
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data
 
-    def get_key_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )
 
+    def _get_key_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.k_buffer[layer_id - self.start_layer]
 
-    def get_value_buffer(self, layer_id: int):
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+        # it is supposed to be used only by attention backend not for information purpose
+        # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]
 
+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)
+
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
 
```
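The comments added to `get_key_buffer` spell out the layer-wise overlap scheme: a loader fills layer `i`'s device buffers from the host pool while the model is still computing earlier layers, and the attention backend's `get_key_buffer(i)` blocks until that layer's transfer is marked done. The private `_get_key_buffer`/`_get_value_buffer` variants skip the wait, which is why `get_contiguous_buf_infos` above switched to them: it only reads pointers and sizes and must not stall on in-flight loads. A minimal sketch of the counter contract (the real counter lives in sglang's cache controller; `SimpleTransferCounter` is illustrative):

```python
import threading

class SimpleTransferCounter:
    """Tracks the highest layer whose host-to-device KV load has completed."""

    def __init__(self):
        self._done = -1
        self._cv = threading.Condition()

    def mark_done(self, layer_id: int):
        # Loader thread: layer_id's KV entries are now resident on device.
        with self._cv:
            self._done = max(self._done, layer_id)
            self._cv.notify_all()

    def wait_until(self, layer_id: int):
        # Attention backend (via get_key_buffer): block until the layer is loaded.
        with self._cv:
            self._cv.wait_for(lambda: self._done >= layer_id)
```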
```diff
@@ -392,10 +423,14 @@ class MHATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
```
```diff
@@ -431,6 +466,206 @@ class MHATokenToKVPool(KVCache):
         )
 
 
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        device: str,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.dtype = dtype
+        self.device = device
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        self.page_size = 1
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+        TokenToKVPoolClass = MHATokenToKVPool
+        self.swa_kv_pool = TokenToKVPoolClass(
+            size=size_swa,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.swa_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.full_kv_pool = TokenToKVPoolClass(
+            size=size,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.full_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
+        for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
+        for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (swa_layer_id, True)
+        self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
+
+    def get_kv_size_bytes(self):
+        raise NotImplementedError
+
+    def get_contiguous_buf_infos(self):
+        full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
+            self.full_kv_pool.get_contiguous_buf_infos()
+        )
+        swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
+            self.swa_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
+        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_key_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_key_buffer(layer_id_pool)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_value_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_value_buffer(layer_id_pool)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_kv_buffer(layer_id_pool)
+
+    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
+        assert self.full_to_swa_index_mapping is not None
+        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+
+        layer_id = layer.layer_id
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            if self.full_to_swa_index_mapping is not None:
+                loc = self.translate_loc_from_full_to_swa(loc)
+            self.swa_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+        else:
+            self.full_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+
+
+class AscendTokenToKVPool(MHATokenToKVPool):
+
+    def _create_buffers(self):
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.head_num,
+                        self.head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.head_num,
+                        self.head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
+    ):
+        layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        import torch_npu
+
+        torch_npu._npu_reshape_and_cache(
+            key=cache_k,
+            value=cache_v,
+            key_cache=self.k_buffer[layer_id].view(
+                -1, self.page_size, self.head_num, self.head_dim
+            ),
+            value_cache=self.v_buffer[layer_id].view(
+                -1, self.page_size, self.head_num, self.head_dim
+            ),
+            slot_indices=loc,
+        )
+
+
 @triton.jit
 def set_mla_kv_buffer_kernel(
     kv_buffer_ptr,
```
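`SWAKVPool` composes two ordinary `MHATokenToKVPool`s, sized independently (`size` for full-attention layers, `size_swa` for sliding-window layers), and routes every accessor through `layers_mapping`, which takes a global layer id to a `(pool-local id, is_swa)` pair. The `layer_id_override` parameter added to `set_kv_buffer` in the previous hunk exists precisely so these sub-pools can be addressed by their own 0-based numbering. A self-contained sketch of the mapping, with a made-up layer layout:

```python
from typing import Dict, Tuple

# Illustrative layout: layers 2 and 5 use full attention, the rest are SWA.
swa_attention_layer_ids = [0, 1, 3, 4]
full_attention_layer_ids = [2, 5]

layers_mapping: Dict[int, Tuple[int, bool]] = {}
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
    layers_mapping[global_layer_id] = (full_attn_layer_id, False)
for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
    layers_mapping[global_layer_id] = (swa_layer_id, True)

assert layers_mapping[3] == (2, True)    # global layer 3 -> 3rd buffer of the SWA pool
assert layers_mapping[5] == (1, False)   # global layer 5 -> 2nd buffer of the full pool
```

Sliding-window layers only ever need the most recent window of tokens, so the SWA pool can be much smaller than the full pool; `full_to_swa_index_mapping` translates allocator slots between the two numbering spaces.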
```diff
@@ -554,12 +789,14 @@ class MLATokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]
 
+        self.token_stride = kv_lora_rank + qk_rope_head_dim
         self.layer_transfer_counter = None
 
         kv_size = self.get_kv_size_bytes()
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
         )
+        self.mem_usage = kv_size / GB
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
```
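`token_stride` is the per-token element count that the transfer kernels receive as `item_size`: `head_num * head_dim` per K and per V buffer in the MHA pool, versus a single compressed `kv_lora_rank + qk_rope_head_dim` vector here. A quick comparison; the dimensions are DeepSeek-style values assumed for illustration:

```python
# MLA stores one shared latent vector per token instead of per-head K and V.
kv_lora_rank, qk_rope_head_dim = 512, 64      # DeepSeek-V2/V3-style dimensions
mla_stride = kv_lora_rank + qk_rope_head_dim  # 576 elements per token

kv_head_num, head_dim = 128, 128              # hypothetical MHA equivalent
mha_stride = 2 * kv_head_num * head_dim       # K + V: 32768 elements per token

print(mla_stride, mha_stride)                 # ~57x less KV storage per token
```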
```diff
@@ -638,21 +875,37 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )
 
-    def …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        self.…
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )
 
-    def …
-        …
-        …
-        …
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )
 
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
```
```diff
@@ -682,6 +935,84 @@ class MLATokenToKVPool(KVCache):
         torch.cuda.synchronize()
 
 
+class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super(MLATokenToKVPool, self).__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
+
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+
+        self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.kv_lora_rank + self.qk_rope_head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(layer_num)
+            ]
+
+        self.layer_transfer_counter = None
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+        self.mem_usage = kv_size / GB
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+
+        import torch_npu
+
+        torch_npu._npu_reshape_and_cache_siso(
+            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
+            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
+                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+            ),
+            slot_indices=loc,
+        )
+
+
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
```
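Both Ascend classes keep the parent pools' bookkeeping and override only the buffer layout (paged `[pages, page_size, ...]` tensors) and the cache-write path, which goes through `torch_npu`'s reshape-and-cache ops instead of Triton. A hedged sketch of how a runner could pick a pool class per platform; the actual selection logic in sglang's model runner may differ:

```python
from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    MLATokenToKVPool,
)

def pick_mla_pool_class(device: str):
    # Ascend NPUs get the paged torch_npu-backed pool; everything else keeps
    # the default MLA pool. Both expose the same constructor surface.
    return AscendMLAPagedTokenToKVPool if device == "npu" else MLATokenToKVPool
```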
```diff
@@ -760,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
-    def …
-    …
-    …
-    …
-    …
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )
 
-    def …
-    …
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )
 
 
 @triton.jit
```