sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py:

```diff
@@ -19,7 +19,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 Memory pool.
 
 SGLang has two levels of memory pool.
-ReqToTokenPool maps a
+ReqToTokenPool maps a request to its token locations.
 TokenToKVPoolAllocator manages the indices to kv cache data.
 KVCache actually holds the physical kv cache.
 """
```
```diff
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()
 
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        raise NotImplementedError()
+
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
 
 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
```
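These three abstract methods define the host-transfer contract that every KV pool (MHA, MLA, double-sparse) must now satisfy, and `register_layer_transfer_counter` moves up from `MHATokenToKVPool` so any pool can participate in layer-wise loading. A minimal sketch of a conforming subclass (the `ToyPool` class and its flat per-layer buffers are illustrative, not part of sglang):

```python
import abc
import torch


class KVCache(abc.ABC):
    """Trimmed-down stand-in for sglang's KVCache base class."""

    @abc.abstractmethod
    def get_flat_data(self, indices):
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer(self, indices, flat_data):
        raise NotImplementedError()

    @abc.abstractmethod
    def transfer_per_layer(self, indices, flat_data, layer_id):
        raise NotImplementedError()


class ToyPool(KVCache):
    """Hypothetical pool: one flat buffer per layer, stacked for transfer."""

    def __init__(self, layer_num, size, dim, device="cpu"):
        self.layer_num = layer_num
        self.device = device
        self.buf = [torch.zeros(size, dim, device=device) for _ in range(layer_num)]

    def get_flat_data(self, indices):
        # One contiguous (layer_num, n, dim) chunk, e.g. for a host backup.
        return torch.stack([self.buf[i][indices] for i in range(self.layer_num)])

    def transfer(self, indices, flat_data):
        flat_data = flat_data.to(self.device)
        for i in range(self.layer_num):
            self.buf[i][indices] = flat_data[i]

    def transfer_per_layer(self, indices, flat_data, layer_id):
        self.buf[layer_id][indices] = flat_data.to(self.device)
```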
```diff
@@ -157,7 +172,7 @@ class TokenToKVPoolAllocator:
             return
 
         if self.is_not_in_free_group:
-            self.free_slots = torch.
+            self.free_slots = torch.cat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)
 
```
```diff
@@ -168,14 +183,14 @@ class TokenToKVPoolAllocator:
     def free_group_end(self):
         self.is_not_in_free_group = True
         if self.free_group:
-            self.free(torch.
+            self.free(torch.cat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
             1, self.size + 1, dtype=torch.int64, device=self.device
         )
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []
 
 
```
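The `free_group_begin()`/`free_group_end()` pair batches many small frees into a single `torch.cat`, and `clear()` now also resets `is_not_in_free_group`, so a reset can no longer leave the allocator stuck deferring frees. A reduced, self-contained sketch of the grouped-free pattern:

```python
import torch


class GroupedFreeAllocator:
    """Reduced sketch of the grouped-free pattern in TokenToKVPoolAllocator."""

    def __init__(self, size, device="cpu"):
        self.size = size
        self.device = device
        self.clear()

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            # Immediate path: merge straight back into the free list.
            self.free_slots = torch.cat((self.free_slots, free_index))
        else:
            # Deferred path: collect now, merge once in free_group_end().
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def clear(self):
        # Slot 0 is reserved for dummy outputs, so usable slots start at 1.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []
```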
```diff
@@ -212,7 +227,8 @@ class MHATokenToKVPool(KVCache):
 
         self.layer_transfer_counter = None
         self.capture_mode = False
-        self.
+        self.device_module = torch.get_device_module(self.device)
+        self.alt_stream = self.device_module.Stream()
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
```
```diff
@@ -255,6 +271,19 @@ class MHATokenToKVPool(KVCache):
             v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
         return k_size_bytes, v_size_bytes
 
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
```
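New in this release: `get_contiguous_buf_infos`, marked `# for disagg`. A disaggregated-prefill transfer engine works from raw buffer addresses rather than tensors, so the method reports, for each layer's K and V buffer, the base device pointer, the total byte length, and the byte size of one token slot; token `t`'s bytes in a buffer then start at `ptr + t * item_len`. A small sketch of that addressing arithmetic (CPU tensors and made-up shapes, illustration only):

```python
import torch

layer_num, size, head_num, head_dim = 2, 8, 4, 16
k_buffers = [
    torch.zeros(size, head_num, head_dim, dtype=torch.float16)
    for _ in range(layer_num)
]

# The three quantities get_contiguous_buf_infos reports per buffer:
ptrs = [b.data_ptr() for b in k_buffers]       # base address of each buffer
lens = [b.nbytes for b in k_buffers]           # total bytes (e.g. to register for transfer)
item_lens = [b[0].nbytes for b in k_buffers]   # bytes occupied by one token slot

token_index = 5
for i in range(layer_num):
    addr = ptrs[i] + token_index * item_lens[i]  # where token 5's K data starts
    assert addr == k_buffers[i][token_index].data_ptr()
```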
```diff
@@ -275,9 +304,6 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]
 
-    def register_layer_transfer_counter(self, layer_transfer_counter):
-        self.layer_transfer_counter = layer_transfer_counter
-
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
```
```diff
@@ -327,11 +353,13 @@ class MHATokenToKVPool(KVCache):
             cache_v = cache_v.view(self.store_dtype)
 
         if self.capture_mode and cache_k.shape[0] < 4:
-
-
+            # Overlap the copy of K and V cache for small batch size
+            current_stream = self.device_module.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            with self.device_module.stream(self.alt_stream):
                 self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v
-
+            current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v
```
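The capture-mode path forks the K-cache copy onto a side stream while the V-cache copy stays on the current stream, then joins, so the two small copies overlap. A self-contained sketch of the same stream idiom (assumes a CUDA device; buffer shapes are made up):

```python
import torch

assert torch.cuda.is_available(), "stream overlap needs a CUDA device"
device = "cuda"
alt_stream = torch.cuda.Stream()

k_buf = torch.zeros(1024, 128, device=device)
v_buf = torch.zeros(1024, 128, device=device)
cache_k = torch.randn(2, 128, device=device)  # small batch, like the shape[0] < 4 path
cache_v = torch.randn(2, 128, device=device)
loc = torch.tensor([3, 7], device=device)

current = torch.cuda.current_stream()
alt_stream.wait_stream(current)        # side stream must see all prior work on `current`
with torch.cuda.stream(alt_stream):
    k_buf[loc] = cache_k               # K copy issued on the side stream
v_buf[loc] = cache_v                   # V copy runs concurrently on the main stream
current.wait_stream(alt_stream)        # rejoin before anything reads k_buf
```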
```diff
@@ -388,6 +416,8 @@ class MLATokenToKVPool(KVCache):
         else:
             self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.layer_num = layer_num
 
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
```
```diff
@@ -404,12 +434,20 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
+        self.layer_transfer_counter = None
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
```
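The buffer getters now block until the hierarchical cache has finished loading that layer from host memory. The counter object itself comes from sglang's cache controller; a minimal sketch of the `wait_until` contract it has to satisfy, built on a condition variable (the class body here is an assumption for illustration, not sglang's implementation):

```python
import threading


class LayerTransferCounter:
    """Hypothetical sketch: wait_until(i) returns once layers 0..i are loaded."""

    def __init__(self):
        self._done_up_to = -1
        self._cond = threading.Condition()

    def notify_layer_done(self, layer_id: int):
        # Called by the transfer thread after layer `layer_id` is on-device.
        with self._cond:
            self._done_up_to = max(self._done_up_to, layer_id)
            self._cond.notify_all()

    def wait_until(self, layer_id: int):
        # Called by the compute thread right before it reads that layer.
        with self._cond:
            self._cond.wait_for(lambda: self._done_up_to >= layer_id)
```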
```diff
@@ -432,6 +470,22 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k
 
+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        for i in range(self.layer_num):
+            self.kv_buffer[i][indices] = flat_data[i]
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        self.kv_buffer[layer_id][indices] = flat_data
+
 
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
```
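`get_flat_data` and `transfer` give `MLATokenToKVPool` the same host backup/restore round trip the MHA pool already had: stack the selected token slots across layers into one contiguous tensor, move it across, and write it back by index. A toy round trip under the same layout:

```python
import torch

layer_num, size, dim = 4, 64, 32
kv_buffer = [torch.randn(size, 1, dim) for _ in range(layer_num)]

# Backup: one contiguous (layer_num, n_tokens, 1, dim) chunk for tokens 10..13.
indices = torch.arange(10, 14)
flat = torch.stack([kv_buffer[i][indices] for i in range(layer_num)])
host_copy = flat.to("cpu").clone()

# Restore the same tokens into fresh slots 20..23.
new_indices = torch.arange(20, 24)
for i in range(layer_num):
    kv_buffer[i][new_indices] = host_copy[i]

assert torch.equal(kv_buffer[0][20], kv_buffer[0][10])
```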
```diff
@@ -508,6 +562,15 @@ class DoubleSparseTokenToKVPool(KVCache):
             self.v_buffer[layer_id][loc] = cache_v
             self.label_buffer[layer_id][loc] = cache_label
 
+    def get_flat_data(self, indices):
+        pass
+
+    def transfer(self, indices, flat_data):
+        pass
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        pass
+
 
 class MemoryStateInt(IntEnum):
     IDLE = 0
```
```diff
@@ -517,21 +580,28 @@ class MemoryStateInt(IntEnum):
     BACKUP = 4
 
 
-def synchronized(
-
-
-
-
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                return func(self, *args, **kwargs)
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True
 
-
+        return wrapper
 
+    return _decorator
 
-
+
+class HostKVCache(abc.ABC):
 
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float
+        host_to_device_ratio: float,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
```
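`synchronized` is now a decorator factory: `@synchronized()` serializes every call on `self.lock`, while `@synchronized(debug_only=True)` turns the method into a no-op returning `True` unless `self.debug` is set (in the released wrapper the first `return` precedes the `with self.lock:` block, so the locked path appears unreachable). A simplified working variant of the intended pattern:

```python
import logging
import threading
from functools import wraps


def synchronized(debug_only=False):
    """Simplified variant: lock every call, or skip entirely unless debugging."""

    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                with self.lock:
                    return func(self, *args, **kwargs)
            return True  # debug-only method with debugging disabled: do nothing

        return wrapper

    return _decorator


class Pool:
    def __init__(self):
        self.lock = threading.RLock()
        self.debug = logging.getLogger(__name__).isEnabledFor(logging.DEBUG)
        self.state = {}

    @synchronized()
    def alloc(self, key):        # always serialized
        self.state[key] = "RESERVED"

    @synchronized(debug_only=True)
    def check_state(self, key):  # bookkeeping only when DEBUG logging is on
        return self.state.get(key)
```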
```diff
@@ -547,12 +617,7 @@ class MHATokenToKVPoolHost:
 
         self.size = int(device_pool.size * host_to_device_ratio)
         self.dtype = device_pool.store_dtype
-        self.
-        self.head_dim = device_pool.head_dim
-        self.layer_num = device_pool.layer_num
-        self.size_per_token = (
-            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-        )
+        self.size_per_token = self.get_size_per_token()
 
         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
```
```diff
@@ -571,126 +636,218 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )
 
-        self.kv_buffer =
-        (2, self.layer_num, self.size, self.head_num, self.head_dim),
-        dtype=self.dtype,
-        device=self.device,
-        pin_memory=self.pin_memory,
-        )
-
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int32)
-        self.can_use_mem_size = self.size
+        self.kv_buffer = self.init_kv_buffer()
 
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()
 
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
     def get_flat_data(self, indices):
-
+        raise NotImplementedError()
 
+    @abc.abstractmethod
     def get_flat_data_by_layer(self, indices, layer_id):
-
+        raise NotImplementedError()
 
+    @abc.abstractmethod
     def assign_flat_data(self, indices, flat_data):
-
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
+        raise NotImplementedError()
 
-    @synchronized
+    @synchronized()
     def clear(self):
-
-        self.
-
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)
 
-
-
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
+    def available_size(self):
+        return len(self.free_slots)
 
-    @synchronized
+    @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.
+        if need_size > self.available_size():
             return None
 
-        # todo: de-fragementation
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        self.
-
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED
 
         return select_index
 
-    @synchronized
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
     def is_reserved(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.RESERVED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_protected(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_synced(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.SYNCED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def is_backup(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.BACKUP
 
-    @synchronized
+    @synchronized(debug_only=True)
     def update_backup(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.BACKUP
 
-    @synchronized
+    @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def protect_write(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def protect_load(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED
 
-    @synchronized
+    @synchronized(debug_only=True)
     def complete_io(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.SYNCED
 
-    def available_size(self):
-        return len(self.free_slots)
 
-
-    def
-    self
-
-
-
+class MHATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
```
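All host-pool sizing now funnels through `get_size_per_token()`: the MHA layout stores K and V per layer per head (hence the `* 2`), while MLA stores one compressed vector of width `kv_lora_rank + qk_rope_head_dim`. For example, 32 layers, 8 KV heads, and head_dim 128 in fp16 come to 131,072 bytes (128 KiB) of host memory per token, which `host_to_device_ratio` then scales against the device pool size. A quick check of that arithmetic:

```python
layer_num, head_num, head_dim = 32, 8, 128
itemsize = 2  # fp16

# MHA host layout (2, layer_num, size, head_num, head_dim): K and V per layer.
mha_bytes_per_token = head_dim * head_num * layer_num * itemsize * 2
assert mha_bytes_per_token == 131_072  # 128 KiB per token

# MLA stores a single compressed vector of width kv_lora_rank + qk_rope_head_dim.
kv_lora_rank, qk_rope_head_dim = 512, 64
mla_bytes_per_token_per_layer = (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize
assert mla_bytes_per_token_per_layer == 1_152
```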
sglang/srt/mem_cache/paged_allocator.py:

```diff
@@ -272,12 +272,12 @@ class PagedTokenToKVPoolAllocator:
     def free_group_end(self):
         self.is_not_in_free_group = True
         if self.free_group:
-            self.free(torch.
+            self.free(torch.cat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
             1, self.num_pages + 1, dtype=torch.int64, device=self.device
         )
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []
```
sglang/srt/mem_cache/radix_cache.py:

```diff
@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache):
         return (
             torch.empty(
                 (0,),
-                dtype=torch.
+                dtype=torch.int64,
                 device=self.device,
             ),
             self.root_node,
```
```diff
@@ -152,9 +152,9 @@ class RadixCache(BasePrefixCache):
 
         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
-            value = torch.
+            value = torch.cat(value)
         else:
-            value = torch.empty((0,), dtype=torch.
+            value = torch.empty((0,), dtype=torch.int64, device=self.device)
         return value, last_node
 
     def insert(self, key: List, value=None):
```
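`match_prefix` collects one index tensor per matched radix-tree edge, so a hit is the `torch.cat` of those pieces, and a miss must return an empty tensor with the same `int64` dtype the allocators now use throughout. A toy illustration of the collect-then-concatenate shape:

```python
import torch

# Each matched radix-tree node contributes the KV indices of its edge.
per_node_hits = [
    torch.tensor([5, 6], dtype=torch.int64),
    torch.tensor([7], dtype=torch.int64),
    torch.tensor([8, 9, 10], dtype=torch.int64),
]

if per_node_hits:
    value = torch.cat(per_node_hits)              # one flat run of matched slots
else:
    value = torch.empty((0,), dtype=torch.int64)  # miss: empty, same dtype

assert value.tolist() == [5, 6, 7, 8, 9, 10]
```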
```diff
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
             _dfs_helper(child)
 
         _dfs_helper(self.root_node)
-        return torch.
+        return torch.cat(values)
 
     ##### Internal Helper Functions #####
 
```