sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those package versions.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/mem_cache/memory_pool.py (+213 -505; see the file list above).

@@ -26,24 +26,17 @@ KVCache actually holds the physical kv cache.
 
 import abc
 import logging
-import threading
-from enum import IntEnum
-from functools import wraps
+from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
-import psutil
 import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import (
-    debug_timing,
-    get_compiler_backend,
-    is_cuda,
-    next_power_of_2,
-)
+from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
 
 logger = logging.getLogger(__name__)
 
@@ -61,6 +54,7 @@ class ReqToTokenPool:
         device: str,
         enable_memory_saver: bool,
     ):
+
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
@@ -68,7 +62,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        with memory_saver_adapter.region():
+        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             self.req_to_token = torch.zeros(
                 (size, max_context_len), dtype=torch.int32, device=device
             )
@@ -128,6 +122,9 @@ class KVCache(abc.ABC):
             enable=enable_memory_saver
         )
 
+        # used for chunked cpu-offloading
+        self.cpu_offloading_chunk_size = 8192
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -150,89 +147,23 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def get_flat_data(self, indices):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer(self, indices, flat_data):
         raise NotImplementedError()
 
-    @abc.abstractmethod
     def transfer_per_layer(self, indices, flat_data, layer_id):
         raise NotImplementedError()
 
     def register_layer_transfer_counter(self, layer_transfer_counter):
         self.layer_transfer_counter = layer_transfer_counter
 
+    def get_cpu_copy(self, indices):
+        raise NotImplementedError()
 
-
-
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        raise NotImplementedError()
 
 
 class MHATokenToKVPool(KVCache):
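The hunk above removes the in-file token-slot free-list allocator (the file list shows the allocator code moving to sglang/srt/mem_cache/allocator.py) and adds get_cpu_copy / load_cpu_copy stubs to the abstract KVCache. One detail worth noting in the removed code is the free-group mechanism: free() calls issued between free_group_begin() and free_group_end() are deferred and released as a single concatenated batch. A self-contained sketch of that behaviour, written as an illustration rather than the sglang class itself:

```python
import torch


class FreeListSketch:
    """Illustrative stand-in for the removed slot allocator (not the sglang class)."""

    def __init__(self, size: int, device: str = "cpu"):
        # Slot 0 is reserved for dummy/padded writes, so usable slots start at 1.
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int64, device=device)
        self.is_not_in_free_group = True
        self.free_group = []

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None  # not enough free slots
        selected = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return selected

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            self.free_slots = torch.cat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)  # deferred until free_group_end()

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))
```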
@@ -263,11 +194,25 @@ class MHATokenToKVPool(KVCache):
 
         self.head_num = head_num
         self.head_dim = head_dim
+
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
         self._create_buffers()
 
         self.layer_transfer_counter = None
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream() if
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -275,25 +220,43 @@ class MHATokenToKVPool(KVCache):
         )
 
     def _create_buffers(self):
-        with self.memory_saver_adapter.region():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.enable_custom_mem_pool
+                else nullcontext()
+            ):
+                # [size, head_num, head_dim] for each layer
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.k_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+                self.v_buffer = [
+                    torch.zeros(
+                        (self.size + self.page_size, self.head_num, self.head_dim),
+                        dtype=self.store_dtype,
+                        device=self.device,
+                    )
+                    for _ in range(self.layer_num)
+                ]
+
+        self.data_ptrs = torch.tensor(
+            [x.data_ptr() for x in self.k_buffer + self.v_buffer],
+            dtype=torch.uint64,
+            device=self.device,
+        )
+        self.data_strides = torch.tensor(
+            [
+                np.prod(x.shape[1:]) * x.dtype.itemsize
+                for x in self.k_buffer + self.v_buffer
+            ],
+            device=self.device,
+        )
 
     def _clear_buffers(self):
         del self.k_buffer
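The rewritten _create_buffers above allocates the per-layer K/V buffers inside an optional torch.cuda.use_mem_pool context (the pool is built in the constructor from Mooncake's NVLinkAllocator when SGLANG_MOONCAKE_CUSTOM_MEM_POOL is enabled) and additionally records data_ptrs/data_strides for the new all-layer copy kernel. A minimal sketch of the optional-pool allocation pattern, with the pool treated as a caller-supplied object (the helper name here is illustrative, not an sglang API):

```python
from contextlib import nullcontext

import torch


def allocate_layer_buffers(layer_num, shape, dtype, device, custom_mem_pool=None):
    # Route allocations through the supplied torch.cuda.MemPool when one is
    # given (e.g. a pool backed by an NVLink-aware allocator); otherwise fall
    # back to the default CUDA caching allocator.
    pool_ctx = (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    )
    with pool_ctx:
        return [
            torch.zeros(shape, dtype=dtype, device=device) for _ in range(layer_num)
        ]
```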
@@ -315,20 +278,66 @@ class MHATokenToKVPool(KVCache):
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
-
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
-
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
         kv_item_lens = [
             self.get_key_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
             self.get_value_buffer(i)[0].nbytes * self.page_size
-            for i in range(self.layer_num)
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append([k_cpu, v_cpu])
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                k_cpu, v_cpu = (
+                    kv_cache_cpu[layer_id][i // chunk_size][0],
+                    kv_cache_cpu[layer_id][i // chunk_size][1],
+                )
+                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
+                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
+                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
+                self.k_buffer[layer_id][chunk_indices] = k_chunk
+                self.v_buffer[layer_id][chunk_indices] = v_chunk
+        torch.cuda.synchronize()
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
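get_cpu_copy and load_cpu_copy above move KV entries between device and host in fixed-size index chunks (cpu_offloading_chunk_size, 8192 in this release), issuing non-blocking copies and synchronizing once per direction. A stripped-down sketch of the same round-trip for a single buffer, assuming a CUDA device (helper names are illustrative):

```python
import torch

CHUNK_SIZE = 8192  # mirrors KVCache.cpu_offloading_chunk_size in the diff


def offload_rows(buf: torch.Tensor, indices: torch.Tensor):
    # GPU -> CPU, chunk by chunk; async copies are completed by one final sync.
    torch.cuda.synchronize()
    chunks = []
    for i in range(0, len(indices), CHUNK_SIZE):
        idx = indices[i : i + CHUNK_SIZE]
        chunks.append(buf[idx].to("cpu", non_blocking=True))
    torch.cuda.synchronize()
    return chunks


def restore_rows(buf: torch.Tensor, chunks, indices: torch.Tensor):
    # CPU -> GPU, writing each chunk back to the same index range.
    torch.cuda.synchronize()
    for i in range(0, len(indices), CHUNK_SIZE):
        idx = indices[i : i + CHUNK_SIZE]
        chunk = chunks[i // CHUNK_SIZE]
        assert chunk.shape[0] == len(idx)
        buf[idx] = chunk.to(buf.device, non_blocking=True)
    torch.cuda.synchronize()
```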
@@ -411,35 +420,15 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[layer_id - self.start_layer][loc] = cache_k
             self.v_buffer[layer_id - self.start_layer][loc] = cache_v
 
-
-
-
-
-
-
-
-
-
-    max_fp8: float,
-    min_fp8: float,
-):
-    cache_k = cache_k / k_scale
-    cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
-    cache_v = cache_v / v_scale
-    cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
-    cache_k = cache_k.to(dtype)
-    cache_v = cache_v.to(dtype)
-    cache_k = cache_k.view(store_dtype)
-    cache_v = cache_v.view(store_dtype)
-    return cache_k, cache_v
-
-
-# This compiled version is slower in the unit test
-# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
-    dst_1[loc] = src_1.to(dtype).view(store_dtype)
-    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
+        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+            self.data_ptrs,
+            self.data_strides,
+            tgt_loc,
+            src_loc,
+            len(tgt_loc),
+            next_power_of_2(len(tgt_loc)),
+        )
 
 
 @triton.jit
@@ -536,16 +525,34 @@ class MLATokenToKVPool(KVCache):
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
 
-        with self.memory_saver_adapter.region():
-
-
-
-
-
-
-
-
-
+        # for disagg with nvlink
+        self.enable_custom_mem_pool = get_bool_env_var(
+            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+        )
+        if self.enable_custom_mem_pool:
+            # TODO(shangming): abstract custom allocator class for more backends
+            from mooncake.allocator import NVLinkAllocator
+
+            allocator = NVLinkAllocator.get_allocator(self.device)
+            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
+        else:
+            self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            with (
+                torch.cuda.use_mem_pool(self.custom_mem_pool)
+                if self.custom_mem_pool
+                else nullcontext()
+            ):
+                # The padded slot 0 is used for writing dummy outputs from padded tokens.
+                self.kv_buffer = [
+                    torch.zeros(
+                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
+                        dtype=self.store_dtype,
+                        device=device,
+                    )
+                    for _ in range(layer_num)
+                ]
 
         self.layer_transfer_counter = None
 
@@ -571,6 +578,9 @@ class MLATokenToKVPool(KVCache):
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
+    def maybe_get_custom_mem_pool(self):
+        return self.custom_mem_pool
+
     def get_key_buffer(self, layer_id: int):
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
@@ -644,6 +654,33 @@ class MLATokenToKVPool(KVCache):
         flat_data = flat_data.to(device=self.device, non_blocking=False)
         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
 
+    def get_cpu_copy(self, indices):
+        torch.cuda.synchronize()
+        kv_cache_cpu = []
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            kv_cache_cpu.append([])
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
+                    "cpu", non_blocking=True
+                )
+                kv_cache_cpu[-1].append(kv_cpu)
+        torch.cuda.synchronize()
+        return kv_cache_cpu
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        torch.cuda.synchronize()
+        chunk_size = self.cpu_offloading_chunk_size
+        for layer_id in range(self.layer_num):
+            for i in range(0, len(indices), chunk_size):
+                chunk_indices = indices[i : i + chunk_size]
+                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
+                assert kv_cpu.shape[0] == len(chunk_indices)
+                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
+                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
+        torch.cuda.synchronize()
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -671,7 +708,7 @@ class DoubleSparseTokenToKVPool(KVCache):
             end_layer,
         )
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -733,368 +770,39 @@ class DoubleSparseTokenToKVPool(KVCache):
         pass
 
 
-
-
-
-
-
-
-
-
-
-
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                return func(self, *args, **kwargs)
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
-
-
-class HostKVCache(abc.ABC):
-
-    def __init__(
-        self,
-        device_pool: KVCache,
-        host_to_device_ratio: float,
-        host_size: int,
-        pin_memory: bool,
-        device: str,
-        page_size: int,
-    ):
-        self.device_pool = device_pool
-        self.dtype = device_pool.store_dtype
-        self.pin_memory = pin_memory
-        self.device = device
-        self.page_size = page_size
-        self.size_per_token = self.get_size_per_token()
-        if host_size > 0:
-            self.size = int(host_size * 1e9 // self.size_per_token)
-        else:
-            self.size = int(device_pool.size * host_to_device_ratio)
-        # Align the host memory pool size to the page size
-        self.size = self.size - (self.size % self.page_size)
-        self.start_layer = device_pool.start_layer
-        self.end_layer = device_pool.end_layer
-
-        assert (
-            self.size > device_pool.size
-        ), "The host memory should be larger than the device memory with the current protocol"
-
-        # Verify there is enough available host memory.
-        host_mem = psutil.virtual_memory()
-        requested_bytes = self.size * self.size_per_token
-        # preserve at least 10GB for other usage
-        ten_gb = 10 * (1024**3)
-        if requested_bytes > host_mem.available - ten_gb:
-            raise ValueError(
-                f"Not enough host memory available. Requesting "
-                f"{requested_bytes / 1e9:.2f} GB but only have "
-                f"{host_mem.available / 1e9:.2f} GB free. Please reduce the "
-                f"size of the hierarchical cache."
-            )
-        else:
-            logger.info(
-                f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
-            )
-
-        self.kv_buffer = self.init_kv_buffer()
-
-        # A lock for synchronized operations on memory allocation and state transitions.
-        self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
-        self.clear()
-
-    @abc.abstractmethod
-    def get_size_per_token(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def init_kv_buffer(self):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @synchronized()
-    def clear(self):
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int64)
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.available_size():
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
-        return select_index
-
-    @synchronized()
-    def free(self, indices: torch.Tensor) -> int:
-        self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
-        return len(indices)
-
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-
-class MHATokenToKVPoolHost(HostKVCache):
-    device_pool: MHATokenToKVPool
-
-    def __init__(
-        self,
-        device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
-
-    def get_size_per_token(self):
-        self.head_num = self.device_pool.head_num
-        self.head_dim = self.device_pool.head_dim
-        self.layer_num = self.device_pool.layer_num
-
-        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-
-    def init_kv_buffer(self):
-        return torch.empty(
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+@triton.jit
+def copy_all_layer_kv_cache(
+    data_ptrs,
+    strides,
+    tgt_loc_ptr,
+    src_loc_ptr,
+    num_locs,
+    num_locs_upper: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 128
 
+    bid = tl.program_id(0)
+    stride = tl.load(strides + bid)
 
-
-
+    data_ptr = tl.load(data_ptrs + bid)
+    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
 
-
-
-
-        host_to_device_ratio: float,
-        host_size: int,
-        page_size: int,
-        pin_memory: bool = True,
-        device: str = "cpu",
-    ):
-        super().__init__(
-            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
-        )
+    num_locs_offset = tl.arange(0, num_locs_upper)
+    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
 
-
-
-        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
-        self.layer_num = self.device_pool.layer_num
+    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
+    # because this copy is an inplace operation.
 
-
-
-
-
-
+    num_loop = tl.cdiv(stride, BLOCK_SIZE)
+    for i in range(num_loop):
+        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = (num_locs_offset < num_locs)[:, None] and (copy_offset < stride)[None, :]
+        value = tl.load(
+            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
         )
-
-
-
-
-                self.layer_num,
-                self.size,
-                1,
-                self.kv_lora_rank + self.qk_rope_head_dim,
-            ),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
+        tl.store(
+            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
+            value,
+            mask=mask,
         )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )