sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
@@ -22,12 +22,22 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-from sglang.srt.distributed import
-
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 
 logger = logging.getLogger(__name__)
 
@@ -231,6 +241,8 @@ class HiCacheController:
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -238,30 +250,37 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False
-
-        # todo: move backend initialization to storage backend module
+
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
-            from sglang.srt.mem_cache.hicache_storage import
+            from sglang.srt.mem_cache.hicache_storage import get_hash_str
+
+            self.get_hash_str = get_hash_str
+            self.storage_config = self._generate_storage_config(
+                model_name, storage_backend_extra_config
+            )
+            # for MLA models, only one rank needs to backup the KV cache
+            self.backup_skip = (
+                self.storage_config.is_mla_model
+                # todo: load balancing
+                and self.storage_config.tp_rank != 0
+            )
 
             if storage_backend == "file":
-
-
+                from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+                self.storage_backend = HiCacheFile(self.storage_config)
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
                 self.storage_backend = HiCacheNixl()
-                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
                 from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
-                    get_hash_str_mooncake,
                 )
 
-                self.storage_backend = MooncakeStore(
-                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend = MooncakeStore(self.storage_config)
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
@@ -279,19 +298,21 @@ class HiCacheController:
                 )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    bytes_per_page, dtype
+                    bytes_per_page, dtype, self.storage_config
                 )
-                self.get_hash_str = get_hash_str
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
 
@@ -302,12 +323,6 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -357,10 +372,45 @@ class HiCacheController:
 
         self.prefetch_revoke_queue = Queue()
         self.ack_backup_queue = Queue()
+        self.host_mem_release_queue = Queue()
 
         self.prefetch_thread.start()
         self.backup_thread.start()
 
+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
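The `_generate_storage_config` helper added above derives the effective TP rank/size (attention-TP when DP attention is enabled, tensor-model-parallel otherwise), detects MLA KV pools, and parses `storage_backend_extra_config` as a JSON string. For reference, a minimal standalone sketch of the same JSON-parsing pattern; the dataclass below is a stand-in that only mirrors the field names visible in this diff, not the real `HiCacheStorageConfig` from `sglang.srt.mem_cache.hicache_storage`:

```python
import json
import logging
from dataclasses import dataclass
from typing import Optional

logger = logging.getLogger(__name__)


@dataclass
class StorageConfigSketch:
    # Stand-in for HiCacheStorageConfig with the fields used in the diff.
    tp_rank: int
    tp_size: int
    is_mla_model: bool
    model_name: Optional[str]
    extra_config: Optional[dict]


def build_storage_config(
    tp_rank: int,
    tp_size: int,
    is_mla_model: bool,
    model_name: Optional[str] = None,
    storage_backend_extra_config: Optional[str] = None,
) -> StorageConfigSketch:
    # Invalid JSON is logged and ignored rather than raised, mirroring the diff.
    extra_config = None
    if storage_backend_extra_config:
        try:
            extra_config = json.loads(storage_backend_extra_config)
        except Exception as e:
            logger.error(f"Invalid backend extra config JSON: {e}")
    return StorageConfigSketch(
        tp_rank, tp_size, is_mla_model, model_name, extra_config
    )


# Backend-specific knobs travel as a JSON string (the key here is made up).
cfg = build_storage_config(
    tp_rank=0,
    tp_size=8,
    is_mla_model=True,
    storage_backend_extra_config='{"example_option": 123}',
)
print(cfg.extra_config)  # {'example_option': 123}
```

The constructor change earlier in this diff then uses `is_mla_model` and `tp_rank` to set `backup_skip`, so for MLA models only rank 0 writes KV pages back to storage, per the in-diff comment.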
@@ -400,15 +450,6 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()
 
-    @property
-    def backup_skip(self):
-        return (
-            self.is_mla
-            and get_tensor_model_parallel_rank() != 0
-            # todo: only support file and mooncake
-            and self.storage_backend_type in ["file", "mooncake"]
-        )
-
     def write(
         self,
         device_indices: torch.Tensor,
@@ -570,60 +611,93 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value
 
-    def
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-
+            hash_values, host_indices
         )
-
-
-
-
-
-
-
-                )
-                break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    completed_tokens += self.page_size
-            else:
-                break
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            operation.increment(self.page_size * len(hashes))
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )
 
-    def
-
-
-
-
-
-
-
-
-
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)
+
+    def _generic_page_get(self, operation, hash_values, host_indices):
+        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+            hash_values
+        )
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
                 logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-
-
-
-
-
-                    page_data[i],
-                )
-                completed_tokens += self.page_size
+            if operation.increment(self.page_size):
+                self.mem_pool_host.set_from_flat_data_page(
+                    host_indices[i * self.page_size],
+                    page_data[i],
+                )
             else:
                 break
 
-    def
-
-
-
-
-
+    def _page_transfer(self, operation):
+        # Select the get function and batch size
+        if self.storage_backend_type == "mooncake":
+            get_func = self._mooncake_page_get
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            get_func = self._3fs_zero_copy_page_get
+        else:
+            get_func = self._generic_page_get
 
-
-
+        # Transfer batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
+            ]
+            prev_completed_tokens = operation.completed_tokens
+            # Get one batch token, and update the completed_tokens if succeed
+            get_func(operation, batch_hashes, batch_host_indices)
+            # Check termination
+            if (
+                operation.completed_tokens
+                != prev_completed_tokens + len(batch_hashes) * self.page_size
+            ):
+                break  # Some operations fail or operation terminated by controller
+        # release pre-allocated memory
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )
 
     def prefetch_io_aux_func(self):
         """
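The new `_page_transfer` above works in units of pages for hashes but in units of tokens for host indices, so a slice of `storage_batch_size` hashes has to be paired with a `page_size`-times longer slice of indices, and a batch whose completed-token count does not advance by the full batch size aborts the loop. A tiny self-contained sketch of that slicing arithmetic (plain lists stand in for the tensors and the storage backend):

```python
# Plain-Python sketch of the batch slicing in _page_transfer / _page_backup:
# hash_value is indexed per page, host_indices per token.
page_size = 4
storage_batch_size = 3

hash_value = [f"page-{p}" for p in range(8)]              # one hash per page
host_indices = list(range(len(hash_value) * page_size))   # one index per token

for i in range(0, len(hash_value), storage_batch_size):
    batch_hashes = hash_value[i : i + storage_batch_size]
    batch_host_indices = host_indices[
        i * page_size : (i + len(batch_hashes)) * page_size
    ]
    print(f"{len(batch_hashes)} pages -> {len(batch_host_indices)} host indices")

# Expected output:
# 3 pages -> 12 host indices
# 3 pages -> 12 host indices
# 2 pages -> 8 host indices
```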
@@ -632,35 +706,50 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-
-                    self.mooncake_page_transfer(operation)
-                elif self.storage_backend_type == "hf3fs":
-                    if self.mem_pool_host.layout == "page_first":
-                        self.zerocopy_page_transfer(operation, batch_size=128)
-                    elif self.mem_pool_host.layout == "layer_first":
-                        self.generic_page_transfer(operation, batch_size=128)
-                else:
-                    self.generic_page_transfer(operation)
-
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
+                self._page_transfer(operation)
                 # operation terminated by controller, release pre-allocated memory
-                self.
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue
 
-    def
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return
+        return False
+
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
+        last_hash = operation.last_hash
+        tokens_to_fetch = operation.token_ids
+
+        storage_query_count = 0
+        hash_value = []
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
+            )
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count
 
     def prefetch_thread_func(self):
         """
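`_storage_hit_query` above folds the old per-page `exists()` loop and the mooncake-specific batch path into one routine: page hashes are chained (each page's hash is derived from the previous one via `get_hash_str`), queried in batches through `batch_exists`, and the returned count is treated as the length of the matched prefix. A self-contained sketch of that flow, with stand-in hash and backend functions in place of the real ones:

```python
import hashlib
from typing import List, Optional, Set, Tuple

PAGE_SIZE = 4
STORAGE_BATCH_SIZE = 2


def chained_page_hash(tokens: List[int], prior: Optional[str]) -> str:
    # Stand-in for get_hash_str: each page hash folds in the previous hash,
    # so page keys form a chain along the token prefix.
    h = hashlib.sha256()
    if prior is not None:
        h.update(prior.encode())
    h.update(str(tokens).encode())
    return h.hexdigest()


def fake_batch_exists(keys: List[str], stored: Set[str]) -> int:
    # Stand-in for storage_backend.batch_exists: number of leading keys present.
    count = 0
    for k in keys:
        if k not in stored:
            break
        count += 1
    return count


def storage_hit_query(token_ids: List[int], stored: Set[str]) -> Tuple[List[str], int]:
    last_hash, hash_value, hit_tokens = None, [], 0
    for start in range(0, len(token_ids), PAGE_SIZE * STORAGE_BATCH_SIZE):
        batch = token_ids[start : start + PAGE_SIZE * STORAGE_BATCH_SIZE]
        batch_hashes = []
        for i in range(0, len(batch), PAGE_SIZE):
            last_hash = chained_page_hash(batch[i : i + PAGE_SIZE], last_hash)
            batch_hashes.append(last_hash)
        hit_pages = fake_batch_exists(batch_hashes, stored)
        hash_value.extend(batch_hashes[:hit_pages])
        hit_tokens += hit_pages * PAGE_SIZE
        if hit_pages < len(batch_hashes):
            break  # first miss ends the usable prefix
    return hash_value, hit_tokens
```

As the following hunks show, when the hit count stays below `prefetch_threshold` the operation is revoked and its pre-allocated host memory is queued for release via `append_host_mem_release`.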
@@ -675,39 +764,7 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                storage_hit_count =
-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    last_hash = operation.last_hash
-                    tokens_to_fetch = operation.token_ids
-
-                    remaining_tokens = len(tokens_to_fetch)
-                    hash_value = []
-                    while remaining_tokens >= self.page_size:
-                        last_hash = self.get_hash_str(
-                            tokens_to_fetch[
-                                storage_hit_count : storage_hit_count + self.page_size
-                            ],
-                            last_hash,
-                        )
-
-                        # todo, more unified interface
-                        if not self.is_mooncake_backend():
-                            if not self.storage_backend.exists(last_hash):
-                                break
-                        hash_value.append(last_hash)
-                        storage_hit_count += self.page_size
-                        remaining_tokens -= self.page_size
-
-                    if self.is_mooncake_backend():
-                        # deferring to batch exists for mooncake store
-                        exist_result = self.storage_backend.exists(hash_value)
-                        storage_hit_count = (
-                            sum(1 for v in exist_result.values() if v != 0)
-                            * self.page_size
-                        )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -722,8 +779,7 @@ class HiCacheController:
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -732,7 +788,9 @@ class HiCacheController:
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
|
@@ -755,59 +813,62 @@ class HiCacheController:
|
|
755
813
|
self.backup_queue.put(operation)
|
756
814
|
return operation.id
|
757
815
|
|
758
|
-
|
816
|
+
# non-zero copy
|
817
|
+
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
818
|
+
data = [
|
819
|
+
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
|
820
|
+
for i in range(len(hash_values))
|
821
|
+
]
|
822
|
+
return self.storage_backend.batch_set(hash_values, data)
|
823
|
+
|
824
|
+
# zero copy
|
825
|
+
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
|
826
|
+
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
827
|
+
hash_values,
|
828
|
+
host_indices,
|
829
|
+
self.storage_config.tp_rank,
|
830
|
+
)
|
831
|
+
success = self.storage_backend.batch_set(
|
832
|
+
key_strs,
|
833
|
+
target_location=buffer_ptrs,
|
834
|
+
target_sizes=buffer_sizes,
|
835
|
+
)
|
836
|
+
return success
|
837
|
+
|
838
|
+
# zero copy
|
839
|
+
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
759
840
|
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
760
|
-
|
841
|
+
hash_values, host_indices
|
761
842
|
)
|
762
|
-
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
|
771
|
-
|
772
|
-
|
773
|
-
|
774
|
-
|
775
|
-
|
776
|
-
|
777
|
-
|
778
|
-
|
843
|
+
return self.storage_backend.batch_set(hashes, dsts)
|
844
|
+
|
845
|
+
# Backup batch by batch
|
846
|
+
def _page_backup(self, operation):
|
847
|
+
# Select the set function and batch size
|
848
|
+
if self.storage_backend_type == "mooncake":
|
849
|
+
backup_set_func = self._mooncake_page_set
|
850
|
+
elif (
|
851
|
+
self.storage_backend_type == "hf3fs"
|
852
|
+
and self.mem_pool_host.layout == "page_first"
|
853
|
+
):
|
854
|
+
backup_set_func = self._3fs_zero_copy_page_set
|
855
|
+
else:
|
856
|
+
backup_set_func = self._generic_page_set
|
857
|
+
# Backup batch by batch
|
858
|
+
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
859
|
+
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
860
|
+
batch_host_indices = operation.host_indices[
|
861
|
+
i * self.page_size : (i + len(batch_hashes)) * self.page_size
|
779
862
|
]
|
780
|
-
|
863
|
+
# Set one batch token, and record if success.
|
864
|
+
# todo: allow partial success
|
865
|
+
success = backup_set_func(batch_hashes, batch_host_indices)
|
781
866
|
if not success:
|
782
|
-
logger.warning(
|
783
|
-
|
784
|
-
operation.completed_tokens += self.page_size * len(page_hashes)
|
785
|
-
|
786
|
-
def mooncake_page_backup(self, operation):
|
787
|
-
if len(operation.hash_value):
|
788
|
-
exist_hashvalues = self.storage_backend.exists(operation.hash_value)
|
789
|
-
indices = operation.host_indices.tolist()
|
790
|
-
non_exist_keys = []
|
791
|
-
non_exist_indices = []
|
792
|
-
for i in range(len(operation.hash_value)):
|
793
|
-
if not exist_hashvalues[operation.hash_value[i]]:
|
794
|
-
non_exist_keys.append(operation.hash_value[i])
|
795
|
-
non_exist_indices.extend(
|
796
|
-
indices[i * self.page_size : (i + 1) * self.page_size]
|
797
|
-
)
|
798
|
-
if len(non_exist_keys) > 0:
|
799
|
-
key_strs, buffer_ptrs, buffer_sizes = (
|
800
|
-
self.mem_pool_host.get_buffer_meta(
|
801
|
-
non_exist_keys, non_exist_indices
|
802
|
-
)
|
803
|
-
)
|
804
|
-
# TODO: check the return value of batch set to see how many tokens are set successfully
|
805
|
-
self.storage_backend.batch_set(
|
806
|
-
key_strs,
|
807
|
-
target_location=buffer_ptrs,
|
808
|
-
target_sizes=buffer_sizes,
|
867
|
+
logger.warning(
|
868
|
+
f"Write page to storage: {len(batch_hashes)} pages failed."
|
809
869
|
)
|
810
|
-
|
870
|
+
break
|
871
|
+
operation.completed_tokens += self.page_size * len(batch_hashes)
|
811
872
|
|
812
873
|
def backup_thread_func(self):
|
813
874
|
"""
|
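The backup path mirrors the prefetch path: `_page_backup` above picks a zero-copy setter for mooncake, a zero-copy setter for hf3fs when the host pool uses the `page_first` layout, and a generic copy path otherwise, then writes batch by batch and stops at the first failed batch. A compact sketch of that dispatch, with plain callables standing in for the bound methods:

```python
from typing import Callable


def select_set_func(
    storage_backend_type: str,
    host_layout: str,
    mooncake_set: Callable,
    hf3fs_zero_copy_set: Callable,
    generic_set: Callable,
) -> Callable:
    # Mirrors the dispatch in _page_backup (and the get-side twin in
    # _page_transfer): zero-copy paths are used only where the backend and
    # host layout support them.
    if storage_backend_type == "mooncake":
        return mooncake_set
    if storage_backend_type == "hf3fs" and host_layout == "page_first":
        return hf3fs_zero_copy_set
    return generic_set


# Dummy callables in place of the real bound methods.
chosen = select_set_func(
    "hf3fs",
    "layer_first",
    mooncake_set=lambda *a: "mooncake",
    hf3fs_zero_copy_set=lambda *a: "hf3fs-zero-copy",
    generic_set=lambda *a: "generic",
)
print(chosen())  # generic: hf3fs with a layer_first host layout falls back to copying
```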
@@ -820,36 +881,8 @@ class HiCacheController:
                     continue
 
                 if not self.backup_skip:
-
-
-                    elif self.storage_backend_type == "hf3fs":
-                        if self.mem_pool_host.layout == "page_first":
-                            self.zerocopy_page_backup(operation, batch_size=128)
-                        elif self.mem_pool_host.layout == "layer_first":
-                            self.generic_page_backup(operation, batch_size=128)
-                    else:
-                        self.generic_page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                    self._page_backup(operation)
+                self.ack_backup_queue.put(operation.id)
 
             except Empty:
                 continue
|