sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -19,7 +19,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 Memory pool.

 SGLang has two levels of memory pool.
-ReqToTokenPool maps a
+ReqToTokenPool maps a request to its token locations.
 TokenToKVPoolAllocator manages the indices to kv cache data.
 KVCache actually holds the physical kv cache.
 """
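As an illustration of the two-level design described in this docstring (simplified stand-in buffers, not the actual sglang classes): a request slot maps to token locations, and those locations index the physical KV buffers.

    # Illustration only: how the two pool levels compose (toy sizes).
    import torch

    num_reqs, max_context, layers, heads, head_dim, pool_size = 4, 16, 2, 8, 64, 128
    req_to_token = torch.zeros((num_reqs, max_context), dtype=torch.int64)  # ReqToTokenPool
    k_buffer = [torch.zeros(pool_size, heads, head_dim) for _ in range(layers)]  # KVCache (keys)

    req_slot, seq_len = 0, 5
    req_to_token[req_slot, :seq_len] = torch.arange(1, seq_len + 1)  # indices handed out by the allocator
    kv_indices = req_to_token[req_slot, :seq_len]
    k = k_buffer[1][kv_indices]  # layer-1 keys for this request, shape (5, 8, 64)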
@@ -115,6 +115,21 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        raise NotImplementedError()
+
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+

 class TokenToKVPoolAllocator:
     """An allocator managing the indices to kv cache data."""
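The three new abstract methods define a flat-data transfer interface that the concrete pools below implement. A minimal sketch of a conforming pool, with an invented buffer layout that mirrors the MLA implementation added later in this diff:

    # Sketch only; real pools carry dtype/device handling and per-layer K/V buffers.
    import torch

    class ToyKVPool:
        def __init__(self, layer_num, size, dim, device="cpu"):
            self.layer_num = layer_num
            self.buf = [torch.zeros(size, dim, device=device) for _ in range(layer_num)]

        def get_flat_data(self, indices):
            # one contiguous chunk covering all layers, for efficient transfer
            return torch.stack([self.buf[i][indices] for i in range(self.layer_num)])

        def transfer(self, indices, flat_data):
            for i in range(self.layer_num):
                self.buf[i][indices] = flat_data[i]

        def transfer_per_layer(self, indices, flat_data, layer_id):
            self.buf[layer_id][indices] = flat_data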
@@ -157,7 +172,7 @@ class TokenToKVPoolAllocator:
             return

         if self.is_not_in_free_group:
-            self.free_slots = torch.
+            self.free_slots = torch.cat((self.free_slots, free_index))
         else:
             self.free_group.append(free_index)

@@ -168,14 +183,14 @@ class TokenToKVPoolAllocator:
     def free_group_end(self):
         self.is_not_in_free_group = True
         if self.free_group:
-            self.free(torch.
+            self.free(torch.cat(self.free_group))

     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = torch.arange(
             1, self.size + 1, dtype=torch.int64, device=self.device
         )
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []


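The free-group mechanism visible in these two hunks batches individual free() calls into one concatenated free. A rough sketch of the pattern, assuming a free_group_begin() counterpart to the free_group_end() shown above:

    # Simplified to the fields the hunks touch; not the full allocator.
    import torch

    class ToyAllocator:
        def __init__(self, size):
            self.free_slots = torch.arange(1, size + 1, dtype=torch.int64)
            self.is_not_in_free_group = True
            self.free_group = []

        def free(self, free_index):
            if free_index.numel() == 0:
                return
            if self.is_not_in_free_group:
                self.free_slots = torch.cat((self.free_slots, free_index))
            else:
                self.free_group.append(free_index)  # deferred until free_group_end

        def free_group_begin(self):
            self.is_not_in_free_group = False
            self.free_group = []

        def free_group_end(self):
            self.is_not_in_free_group = True
            if self.free_group:
                self.free(torch.cat(self.free_group))  # one concat instead of many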
@@ -212,7 +227,8 @@ class MHATokenToKVPool(KVCache):

         self.layer_transfer_counter = None
         self.capture_mode = False
-        self.
+        self.device_module = torch.get_device_module(self.device)
+        self.alt_stream = self.device_module.Stream()

         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -255,6 +271,19 @@ class MHATokenToKVPool(KVCache):
             v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize
         return k_size_bytes, v_size_bytes

+    # for disagg
+    def get_contiguous_buf_infos(self):
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
+        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     # Todo: different memory layout
     def get_flat_data(self, indices):
         # prepare a large chunk of contiguous data for efficient transfer
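The new get_contiguous_buf_infos() exposes raw pointer/length metadata, presumably for the prefill/decode disaggregation path added elsewhere in this release. A hypothetical consumer (the transfer-engine registration is a stand-in, not an sglang API) would use the triples like this:

    # Mock per-layer key buffers standing in for the pool's K/V tensors.
    import torch

    k = [torch.zeros(128, 8, 64) for _ in range(2)]
    ptrs = [t.data_ptr() for t in k]       # base address of each contiguous region
    lens = [t.nbytes for t in k]           # total bytes, for memory registration
    item_lens = [t[0].nbytes for t in k]   # bytes per token slot

    # A receiver can address one token's data inside a registered region:
    def token_addr(buf_idx, token_idx):
        return ptrs[buf_idx] + token_idx * item_lens[buf_idx]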
@@ -275,9 +304,6 @@ class MHATokenToKVPool(KVCache):
             self.k_buffer[i][indices] = k_data[i]
             self.v_buffer[i][indices] = v_data[i]

-    def register_layer_transfer_counter(self, layer_transfer_counter):
-        self.layer_transfer_counter = layer_transfer_counter
-
     def transfer_per_layer(self, indices, flat_data, layer_id):
         # transfer prepared data from host to device
         flat_data = flat_data.to(device=self.device, non_blocking=False)
@@ -326,12 +352,14 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)

-        if self.capture_mode:
-
-
+        if self.capture_mode and cache_k.shape[0] < 4:
+            # Overlap the copy of K and V cache for small batch size
+            current_stream = self.device_module.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            with self.device_module.stream(self.alt_stream):
                 self.k_buffer[layer_id][loc] = cache_k
                 self.v_buffer[layer_id][loc] = cache_v
-
+            current_stream.wait_stream(self.alt_stream)
         else:
             self.k_buffer[layer_id][loc] = cache_k
             self.v_buffer[layer_id][loc] = cache_v
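The change above forks the cache writes onto an alternate stream during capture. A generic fork-join sketch of that pattern (illustrative tensors; requires a CUDA device):

    import torch

    if torch.cuda.is_available():
        device_module = torch.get_device_module("cuda")
        alt_stream = device_module.Stream()
        buf = torch.zeros(1024, device="cuda")
        src = torch.ones(4, device="cuda")
        loc = torch.arange(4, device="cuda")

        current_stream = device_module.current_stream()
        alt_stream.wait_stream(current_stream)      # side stream sees all prior work
        with device_module.stream(alt_stream):
            buf[loc] = src                          # copy runs on the side stream
        current_stream.wait_stream(alt_stream)      # rejoin before dependent kernels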
@@ -388,6 +416,8 @@ class MLATokenToKVPool(KVCache):
         else:
             self.store_dtype = dtype
         self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.layer_num = layer_num

         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
@@ -404,12 +434,20 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.layer_transfer_counter = None
+
     def get_key_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id].view(self.dtype)
         return self.kv_buffer[layer_id]

     def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id)
+
         if self.store_dtype != self.dtype:
             return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype)
         return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
@@ -432,6 +470,22 @@ class MLATokenToKVPool(KVCache):
         else:
             self.kv_buffer[layer_id][loc] = cache_k

+    def get_flat_data(self, indices):
+        # prepare a large chunk of contiguous data for efficient transfer
+        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        for i in range(self.layer_num):
+            self.kv_buffer[i][indices] = flat_data[i]
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        # transfer prepared data from host to device
+        flat_data = flat_data.to(device=self.device, non_blocking=False)
+        self.kv_buffer[layer_id][indices] = flat_data
+

 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
@@ -508,6 +562,15 @@ class DoubleSparseTokenToKVPool(KVCache):
             self.v_buffer[layer_id][loc] = cache_v
             self.label_buffer[layer_id][loc] = cache_label

+    def get_flat_data(self, indices):
+        pass
+
+    def transfer(self, indices, flat_data):
+        pass
+
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        pass
+

 class MemoryStateInt(IntEnum):
     IDLE = 0
@@ -517,21 +580,28 @@ class MemoryStateInt(IntEnum):
     BACKUP = 4


-def synchronized(
-
-
-
-
+def synchronized(debug_only=False):
+    def _decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            if (not debug_only) or self.debug:
+                return func(self, *args, **kwargs)
+                with self.lock:
+                    return func(self, *args, **kwargs)
+            else:
+                return True

-
+        return wrapper

+    return _decorator

-
+
+class HostKVCache(abc.ABC):

     def __init__(
         self,
         device_pool: MHATokenToKVPool,
-        host_to_device_ratio: float
+        host_to_device_ratio: float,
         pin_memory: bool = False,  # no need to use pin memory with the double buffering
         device: str = "cpu",
     ):
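For reference, the reworked decorator is a factory: @synchronized() always runs the method, while @synchronized(debug_only=True) turns the method into a no-op returning True unless self.debug is set. A usage sketch with a simplified equivalent of the factory (the shipped version above differs in where it takes the lock):

    import threading
    from functools import wraps

    def synchronized(debug_only=False):
        def _decorator(func):
            @wraps(func)
            def wrapper(self, *args, **kwargs):
                if (not debug_only) or self.debug:
                    with self.lock:
                        return func(self, *args, **kwargs)
                return True  # debug-only check skipped outside debug mode
            return wrapper
        return _decorator

    class ToyHostPool:
        def __init__(self):
            self.lock = threading.RLock()
            self.debug = False

        @synchronized()
        def alloc(self, n):            # always runs, under the lock
            return list(range(n))

        @synchronized(debug_only=True)
        def is_reserved(self, idx):    # becomes a no-op (True) unless self.debug
            return idx >= 0

    pool = ToyHostPool()
    assert pool.alloc(3) == [0, 1, 2]
    assert pool.is_reserved(-1) is True  # skipped: debug disabled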
@@ -547,12 +617,7 @@ class MHATokenToKVPoolHost:

         self.size = int(device_pool.size * host_to_device_ratio)
         self.dtype = device_pool.store_dtype
-        self.
-        self.head_dim = device_pool.head_dim
-        self.layer_num = device_pool.layer_num
-        self.size_per_token = (
-            self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
-        )
+        self.size_per_token = self.get_size_per_token()

         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
@@ -571,123 +636,218 @@ class MHATokenToKVPoolHost:
             f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
         )

-        self.kv_buffer =
-            (2, self.layer_num, self.size, self.head_num, self.head_dim),
-            dtype=self.dtype,
-            device=self.device,
-            pin_memory=self.pin_memory,
-        )
-
-        # Initialize memory states and tracking structures.
-        self.mem_state = torch.zeros(
-            (self.size,), dtype=torch.uint8, device=self.device
-        )
-        self.free_slots = torch.arange(self.size, dtype=torch.int32)
-        self.can_use_mem_size = self.size
+        self.kv_buffer = self.init_kv_buffer()

         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
+        self.debug = logger.isEnabledFor(logging.DEBUG)
+        self.clear()

-
-
+    @abc.abstractmethod
+    def get_size_per_token(self):
+        raise NotImplementedError()

-
-
+    @abc.abstractmethod
+    def init_kv_buffer(self):
+        raise NotImplementedError()

-    @
+    @abc.abstractmethod
     def transfer(self, indices, flat_data):
-
-
-
-
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data_by_layer(self, indices, layer_id):
+        raise NotImplementedError()

-    @
+    @abc.abstractmethod
+    def assign_flat_data(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @synchronized()
     def clear(self):
-
-        self.
-
+        # Initialize memory states and tracking structures.
+        self.mem_state = torch.zeros(
+            (self.size,), dtype=torch.uint8, device=self.device
+        )
+        self.free_slots = torch.arange(self.size, dtype=torch.int64)

-
-
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
+    def available_size(self):
+        return len(self.free_slots)

-    @synchronized
+    @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
-        if need_size > self.
+        if need_size > self.available_size():
             return None

-        # todo: de-fragementation
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]

-        self.
-
+        if self.debug:
+            self.mem_state[select_index] = MemoryStateInt.RESERVED

         return select_index

-    @synchronized
+    @synchronized()
+    def free(self, indices: torch.Tensor) -> int:
+        self.free_slots = torch.cat([self.free_slots, indices])
+        if self.debug:
+            self.mem_state[indices] = MemoryStateInt.IDLE
+        return len(indices)
+
+    @synchronized(debug_only=True)
+    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
+        assert len(indices) > 0, "The indices should not be empty"
+        states = self.mem_state[indices]
+        assert (
+            states == states[0]
+        ).all(), "The memory slots should have the same state {}".format(states)
+        return MemoryStateInt(states[0].item())
+
+    @synchronized(debug_only=True)
     def is_reserved(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.RESERVED

-    @synchronized
+    @synchronized(debug_only=True)
     def is_protected(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.PROTECTED

-    @synchronized
+    @synchronized(debug_only=True)
     def is_synced(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.SYNCED

-    @synchronized
+    @synchronized(debug_only=True)
     def is_backup(self, indices: torch.Tensor) -> bool:
         return self.get_state(indices) == MemoryStateInt.BACKUP

-    @synchronized
+    @synchronized(debug_only=True)
     def update_backup(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_synced(indices):
+            raise ValueError(
+                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.BACKUP

-    @synchronized
+    @synchronized(debug_only=True)
     def update_synced(self, indices: torch.Tensor):
         self.mem_state[indices] = MemoryStateInt.SYNCED

-    @synchronized
+    @synchronized(debug_only=True)
     def protect_write(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_reserved(indices):
+            raise ValueError(
+                f"The host memory slots should be RESERVED before write operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED

-    @synchronized
+    @synchronized(debug_only=True)
     def protect_load(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_backup(indices):
+            raise ValueError(
+                f"The host memory slots should be in BACKUP state before load operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.PROTECTED

-    @synchronized
+    @synchronized(debug_only=True)
     def complete_io(self, indices: torch.Tensor):
-
-
-
-
+        if not self.is_protected(indices):
+            raise ValueError(
+                f"The host memory slots should be PROTECTED during I/O operations. "
+                f"Current state: {self.get_state(indices)}"
+            )
         self.mem_state[indices] = MemoryStateInt.SYNCED

-    def available_size(self):
-        return len(self.free_slots)

-
-    def
-        self
-
-
-
+class MHATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MHATokenToKVPool,
+        host_to_device_ratio: float,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.head_num = self.device_pool.head_num
+        self.head_dim = self.device_pool.head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (2, self.layer_num, self.size, self.head_num, self.head_dim),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, :, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, :, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[:, layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, :, indices] = flat_data
+
+
+class MLATokenToKVPoolHost(HostKVCache):
+    def __init__(
+        self,
+        device_pool: MLATokenToKVPool,
+        host_to_device_ratio: float,
+        pin_memory: bool = False,  # no need to use pin memory with the double buffering
+        device: str = "cpu",
+    ):
+        super().__init__(device_pool, host_to_device_ratio, pin_memory, device)
+
+    def get_size_per_token(self):
+        self.kv_lora_rank = self.device_pool.kv_lora_rank
+        self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
+        self.layer_num = self.device_pool.layer_num
+
+        return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
+
+    def init_kv_buffer(self):
+        return torch.empty(
+            (
+                self.layer_num,
+                self.size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            ),
+            dtype=self.dtype,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    @debug_timing
+    def transfer(self, indices, flat_data):
+        # backup prepared data from device to host
+        self.kv_buffer[:, indices] = flat_data.to(
+            device=self.device, non_blocking=False
+        )
+
+    def get_flat_data(self, indices):
+        return self.kv_buffer[:, indices]
+
+    def get_flat_data_by_layer(self, indices, layer_id):
+        return self.kv_buffer[layer_id, indices]
+
+    def assign_flat_data(self, indices, flat_data):
+        self.kv_buffer[:, indices] = flat_data
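The two host-cache subclasses differ mainly in their per-token byte cost. A worked example of get_size_per_token() using assumed model dimensions (not taken from any specific checkpoint):

    # Assumed dims: head_num=8, head_dim=128, layer_num=32, fp16 storage.
    itemsize = 2  # bytes per fp16 element

    # MHA layout: K and V per layer, hence the trailing "* 2".
    mha_per_token = 128 * 8 * 32 * itemsize * 2          # 131072 bytes (128 KiB)

    # MLA layout: one compressed vector of kv_lora_rank + qk_rope_head_dim entries
    # (per the hunk above, layer_num does not enter this product).
    kv_lora_rank, qk_rope_head_dim = 512, 64
    mla_per_token = (kv_lora_rank + qk_rope_head_dim) * 1 * itemsize  # 1152 bytes

    print(mha_per_token, mla_per_token)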
sglang/srt/mem_cache/paged_allocator.py

@@ -272,12 +272,12 @@ class PagedTokenToKVPoolAllocator:
     def free_group_end(self):
         self.is_not_in_free_group = True
         if self.free_group:
-            self.free(torch.
+            self.free(torch.cat(self.free_group))

     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
             1, self.num_pages + 1, dtype=torch.int64, device=self.device
         )
-        self.
+        self.is_not_in_free_group = True
         self.free_group = []
sglang/srt/mem_cache/radix_cache.py

@@ -140,7 +140,7 @@ class RadixCache(BasePrefixCache):
         return (
             torch.empty(
                 (0,),
-                dtype=torch.
+                dtype=torch.int64,
                 device=self.device,
             ),
             self.root_node,
@@ -152,9 +152,9 @@ class RadixCache(BasePrefixCache):

         value, last_node = self._match_prefix_helper(self.root_node, key)
         if value:
-            value = torch.
+            value = torch.cat(value)
         else:
-            value = torch.empty((0,), dtype=torch.
+            value = torch.empty((0,), dtype=torch.int64, device=self.device)
         return value, last_node

     def insert(self, key: List, value=None):
@@ -317,7 +317,7 @@ class RadixCache(BasePrefixCache):
             _dfs_helper(child)

         _dfs_helper(self.root_node)
-        return torch.
+        return torch.cat(values)

     ##### Internal Helper Functions #####
