sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -26,24 +26,15 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
-import threading
-from enum import IntEnum
-from functools import wraps
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
-import psutil
 import torch
 import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import (
-    debug_timing,
-    get_compiler_backend,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -150,15 +141,12 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def get_flat_data(self, indices):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer(self, indices, flat_data):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer_per_layer(self, indices, flat_data, layer_id):
         raise NotImplementedError()
 
@@ -191,6 +179,9 @@ class TokenToKVPoolAllocator:
     def available_size(self):
         return len(self.free_slots)
 
+    def debug_print(self) -> str:
+        return ""
+
     def get_kvcache(self):
         return self._kvcache
 
@@ -234,6 +225,12 @@ class TokenToKVPoolAllocator:
         self.is_not_in_free_group = True
         self.free_group = []
 
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+
 
 class MHATokenToKVPool(KVCache):
 
@@ -265,9 +262,11 @@ class MHATokenToKVPool(KVCache):
         self.head_dim = head_dim
         self._create_buffers()
 
+        # used for chunked cpu-offloading
+        self.chunk_size = 8192
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -295,6 +294,19 @@ class MHATokenToKVPool(KVCache):
             for _ in range(self.layer_num)
         ]
 
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.data_strides = torch.tensor(
+            [
+                np.prod(x.shape[1:]) * x.dtype.itemsize
+                for x in self.k_buffer + self.v_buffer
+            ],
+            device=self.device,
+        )
+
     def _clear_buffers(self):
         del self.k_buffer
         del self.v_buffer
@@ -315,20 +327,61 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_item_lens = [
             self.get_key_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
             self.get_value_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), self.chunk_size):
+                chunk_indices = indices[i : i + self.chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // self.chunk_size][0],
+                    kv_cache_cpu[layer_id][i // self.chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
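
The `get_cpu_copy` / `load_cpu_copy` methods added above follow a simple chunked-offload pattern: the requested token indices are processed in fixed slices of `self.chunk_size` (8192) rows, each slice of the K and V buffers is moved with a non-blocking transfer, and `torch.cuda.synchronize()` calls fence the whole batch. A minimal self-contained sketch of the same round trip, using a stand-in buffer rather than the sglang classes (names and sizes here are illustrative only):

import torch

CHUNK_SIZE = 4  # sglang's MHATokenToKVPool uses 8192; small here for illustration

device = "cuda" if torch.cuda.is_available() else "cpu"
# Stand-in for one layer's K (or V) buffer: [num_tokens, head_num, head_dim]
buf = torch.randn(16, 2, 8, device=device)
indices = torch.tensor([1, 3, 5, 7, 9, 11])

# Offload the selected rows to host memory chunk by chunk (cf. get_cpu_copy).
cpu_chunks = []
for i in range(0, len(indices), CHUNK_SIZE):
    chunk = indices[i : i + CHUNK_SIZE]
    cpu_chunks.append(buf[chunk].to("cpu", non_blocking=True))
if device == "cuda":
    torch.cuda.synchronize()  # fence the asynchronous device-to-host copies

# Later, restore the rows from the CPU copy (cf. load_cpu_copy).
for i in range(0, len(indices), CHUNK_SIZE):
    chunk = indices[i : i + CHUNK_SIZE]
    buf[chunk] = cpu_chunks[i // CHUNK_SIZE].to(device, non_blocking=True)
if device == "cuda":
    torch.cuda.synchronize()
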
@@ -411,35 +464,15 @@ class MHATokenToKVPool(KVCache):
         self.k_buffer[layer_id - self.start_layer][loc] = cache_k
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
-    max_fp8: float,
-    min_fp8: float,
-):
-    cache_k = cache_k / k_scale
-    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
-    cache_v = cache_v / v_scale
-    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
-    cache_k = cache_k.to(dtype)
-    cache_v = cache_v.to(dtype)
-    cache_k = cache_k.view(store_dtype)
-    cache_v = cache_v.view(store_dtype)
-    return cache_k, cache_v
-
-
-# This compiled version is slower in the unit test
-# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
-    dst_1[loc] = src_1.to(dtype).view(store_dtype)
-    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
+        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+            self.data_ptrs,
+            self.data_strides,
+            tgt_loc,
+            src_loc,
+            len(tgt_loc),
+            next_power_of_2(len(tgt_loc)),
+        )
 
 
 @triton.jit
@@ -733,368 +766,39 @@ class DoubleSparseTokenToKVPool(KVCache):
         pass
 
 
[removed here (365 lines): the MemoryStateInt state enum, the synchronized() locking decorator, and the host-side KV cache classes HostKVCache, MHATokenToKVPoolHost, and MLATokenToKVPoolHost; a new sglang/srt/mem_cache/memory_pool_host.py (+380 lines) is added in this release]
+@triton.jit
+def copy_all_layer_kv_cache(
+    data_ptrs,
+    strides,
+    tgt_loc_ptr,
+    src_loc_ptr,
+    num_locs,
+    num_locs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
+
+    bid = tl.program_id(0)
+    stride = tl.load(strides + bid)
+
+    data_ptr = tl.load(data_ptrs + bid)
+    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+
+    num_locs_offset = tl.arange(0, num_locs_upper)
+    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+
+    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
+    # because this copy is an inplace operation.
+
+    num_loop = tl.cdiv(stride, BLOCK_SIZE)
+    for i in range(num_loop):
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        value = tl.load(
+            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
+        )
+        tl.store(
+            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
+            value,
+            mask=mask,
+        )