sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0

The per-file diff reproduced below is for sglang/srt/managers/cache_controller.py (+237 -204).

@@ -22,12 +22,22 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.mem_cache.memory_pool_host import MLATokenToKVPoolHost
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 
 logger = logging.getLogger(__name__)
 
@@ -231,6 +241,8 @@ class HiCacheController:
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -238,30 +250,37 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False
-        self.is_mla = isinstance(self.mem_pool_host, MLATokenToKVPoolHost)
-        # todo: move backend initialization to storage backend module
+
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
-            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
+            from sglang.srt.mem_cache.hicache_storage import get_hash_str
+
+            self.get_hash_str = get_hash_str
+            self.storage_config = self._generate_storage_config(
+                model_name, storage_backend_extra_config
+            )
+            # for MLA models, only one rank needs to backup the KV cache
+            self.backup_skip = (
+                self.storage_config.is_mla_model
+                # todo: load balancing
+                and self.storage_config.tp_rank != 0
+            )
 
             if storage_backend == "file":
-                self.storage_backend = HiCacheFile(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str
+                from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+                self.storage_backend = HiCacheFile(self.storage_config)
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
                 self.storage_backend = HiCacheNixl()
-                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
                 from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
-                    get_hash_str_mooncake,
                 )
 
-                self.storage_backend = MooncakeStore(is_mla=self.is_mla)
-                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend = MooncakeStore(self.storage_config)
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
@@ -279,19 +298,21 @@ class HiCacheController:
                 )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    bytes_per_page, dtype
+                    bytes_per_page, dtype, self.storage_config
                 )
-                self.get_hash_str = get_hash_str
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
 
@@ -302,12 +323,6 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
 
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -357,10 +372,45 @@ class HiCacheController:
 
         self.prefetch_revoke_queue = Queue()
         self.ack_backup_queue = Queue()
+        self.host_mem_release_queue = Queue()
 
         self.prefetch_thread.start()
        self.backup_thread.start()
 
+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
         self.write_thread.join()
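
For orientation, here is a standalone sketch (not part of the diff) of what the new _generate_storage_config step does: resolve the effective TP rank/size, flag MLA pools, parse the optional extra-config JSON, and derive the MLA-only backup_skip decision shown in the earlier hunk. The dataclass below is a stand-in exposing only the fields visible in this diff, and the option key and model name are made-up illustrations.

import json
from dataclasses import dataclass
from typing import Optional

@dataclass
class StorageConfigSketch:  # stand-in for sglang.srt.mem_cache.hicache_storage.HiCacheStorageConfig
    tp_rank: int
    tp_size: int
    is_mla_model: bool
    model_name: Optional[str]
    extra_config: Optional[dict]

def generate_storage_config(tp_rank, tp_size, is_mla_model,
                            model_name=None, extra_config_json=None):
    extra_config = None
    if extra_config_json:
        try:
            # invalid JSON is logged and ignored rather than raised, as in the diff
            extra_config = json.loads(extra_config_json)
        except Exception as e:
            print(f"Invalid backend extra config JSON: {e}")
    return StorageConfigSketch(tp_rank, tp_size, is_mla_model, model_name, extra_config)

cfg = generate_storage_config(
    tp_rank=1, tp_size=8, is_mla_model=True,
    model_name="deepseek-ai/DeepSeek-V3",          # illustrative value
    extra_config_json='{"example_option": 123}',   # made-up key, shape only
)
# MLA KV pages are identical across TP ranks, so only rank 0 backs them up:
backup_skip = cfg.is_mla_model and cfg.tp_rank != 0  # True here
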
@@ -400,15 +450,6 @@ class HiCacheController:
         self.prefetch_thread.start()
         self.backup_thread.start()
 
-    @property
-    def backup_skip(self):
-        return (
-            self.is_mla
-            and get_tensor_model_parallel_rank() != 0
-            # todo: only support file and mooncake
-            and self.storage_backend_type in ["file", "mooncake"]
-        )
-
     def write(
         self,
         device_indices: torch.Tensor,
@@ -570,60 +611,93 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value
 
-    def zerocopy_page_transfer(self, operation, batch_size=8):
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
+            hash_values, host_indices
         )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_dsts = dsts[i : i + batch_size]
-            page_data = self.storage_backend.batch_get(page_hashes, page_dsts)
-            if page_data is None:
-                logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
-                )
-                break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    completed_tokens += self.page_size
-            else:
-                break
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            operation.increment(self.page_size * len(hashes))
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )
 
-    def generic_page_transfer(self, operation, batch_size=8):
-        for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            # todo: zero copy
-            dummy_page_dst = [
-                self.mem_pool_host.get_dummy_flat_data_page()
-                for _ in range(len(page_hashes))
-            ]
-            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
-            if page_data is None:
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)
+
+    def _generic_page_get(self, operation, hash_values, host_indices):
+        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+            hash_values
+        )
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
                 logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            completed_tokens = operation.completed_tokens
-            if operation.increment(self.page_size * len(page_hashes)):
-                for i in range(len(page_hashes)):
-                    self.mem_pool_host.set_from_flat_data_page(
-                        operation.host_indices[completed_tokens],
-                        page_data[i],
-                    )
-                    completed_tokens += self.page_size
+            if operation.increment(self.page_size):
+                self.mem_pool_host.set_from_flat_data_page(
+                    host_indices[i * self.page_size],
+                    page_data[i],
+                )
             else:
                 break
 
-    def mooncake_page_transfer(self, operation):
-        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
-            operation.hash_value, operation.host_indices
-        )
-        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
-        operation.increment(len(operation.hash_value) * self.page_size)
+    def _page_transfer(self, operation):
+        # Select the get function and batch size
+        if self.storage_backend_type == "mooncake":
+            get_func = self._mooncake_page_get
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            get_func = self._3fs_zero_copy_page_get
+        else:
+            get_func = self._generic_page_get
 
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
+        # Transfer batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
+            ]
+            prev_completed_tokens = operation.completed_tokens
+            # Get one batch token, and update the completed_tokens if succeed
+            get_func(operation, batch_hashes, batch_host_indices)
+            # Check termination
+            if (
+                operation.completed_tokens
+                != prev_completed_tokens + len(batch_hashes) * self.page_size
+            ):
+                break  # Some operations fail or operation terminated by controller
+        # release pre-allocated memory
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )
 
     def prefetch_io_aux_func(self):
         """
@@ -632,35 +706,50 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if self.is_mooncake_backend():
-                    self.mooncake_page_transfer(operation)
-                elif self.storage_backend_type == "hf3fs":
-                    if self.mem_pool_host.layout == "page_first":
-                        self.zerocopy_page_transfer(operation, batch_size=128)
-                    elif self.mem_pool_host.layout == "layer_first":
-                        self.generic_page_transfer(operation, batch_size=128)
-                else:
-                    self.generic_page_transfer(operation)
-
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
+                self._page_transfer(operation)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue
 
-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False
+
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
+        last_hash = operation.last_hash
+        tokens_to_fetch = operation.token_ids
+
+        storage_query_count = 0
+        hash_value = []
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
+            )
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count
 
     def prefetch_thread_func(self):
         """
@@ -675,39 +764,7 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                storage_hit_count = 0
-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    last_hash = operation.last_hash
-                    tokens_to_fetch = operation.token_ids
-
-                    remaining_tokens = len(tokens_to_fetch)
-                    hash_value = []
-                    while remaining_tokens >= self.page_size:
-                        last_hash = self.get_hash_str(
-                            tokens_to_fetch[
-                                storage_hit_count : storage_hit_count + self.page_size
-                            ],
-                            last_hash,
-                        )
-
-                        # todo, more unified interface
-                        if not self.is_mooncake_backend():
-                            if not self.storage_backend.exists(last_hash):
-                                break
-                        hash_value.append(last_hash)
-                        storage_hit_count += self.page_size
-                        remaining_tokens -= self.page_size
-
-                    if self.is_mooncake_backend():
-                        # deferring to batch exists for mooncake store
-                        exist_result = self.storage_backend.exists(hash_value)
-                        storage_hit_count = (
-                            sum(1 for v in exist_result.values() if v != 0)
-                            * self.page_size
-                        )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -722,8 +779,7 @@ class HiCacheController:
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                    if operation.host_indices is not None:
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -732,7 +788,9 @@ class HiCacheController:
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -755,59 +813,62 @@ class HiCacheController:
         self.backup_queue.put(operation)
         return operation.id
 
-    def zerocopy_page_backup(self, operation, batch_size=8):
+    # non-zero copy
+    def _generic_page_set(self, hash_values, host_indices) -> bool:
+        data = [
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            for i in range(len(hash_values))
+        ]
+        return self.storage_backend.batch_set(hash_values, data)
+
+    # zero copy
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_location=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success
+
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
-            operation.hash_value, operation.host_indices
+            hash_values, host_indices
         )
-        for i in range(0, len(hashes), batch_size):
-            page_hashes = hashes[i : i + batch_size]
-            page_data = dsts[i : i + batch_size]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
-            if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
-
-    def generic_page_backup(self, operation, batch_size=8):
-        for i in range(0, len(operation.hash_value), batch_size):
-            page_hashes = operation.hash_value[i : i + batch_size]
-            page_data = [
-                self.mem_pool_host.get_flat_data_page(
-                    operation.host_indices[j * self.page_size]
-                )
-                for j in range(i, i + len(page_hashes))
+        return self.storage_backend.batch_set(hashes, dsts)
+
+    # Backup batch by batch
+    def _page_backup(self, operation):
+        # Select the set function and batch size
+        if self.storage_backend_type == "mooncake":
+            backup_set_func = self._mooncake_page_set
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            backup_set_func = self._3fs_zero_copy_page_set
+        else:
+            backup_set_func = self._generic_page_set
+        # Backup batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
-            success = self.storage_backend.batch_set(page_hashes, page_data)
+            # Set one batch token, and record if success.
+            # todo: allow partial success
+            success = backup_set_func(batch_hashes, batch_host_indices)
             if not success:
-                logger.warning(f"Failed to write page {page_hashes} to storage.")
-                break
-            operation.completed_tokens += self.page_size * len(page_hashes)
-
-    def mooncake_page_backup(self, operation):
-        if len(operation.hash_value):
-            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
-            indices = operation.host_indices.tolist()
-            non_exist_keys = []
-            non_exist_indices = []
-            for i in range(len(operation.hash_value)):
-                if not exist_hashvalues[operation.hash_value[i]]:
-                    non_exist_keys.append(operation.hash_value[i])
-                    non_exist_indices.extend(
-                        indices[i * self.page_size : (i + 1) * self.page_size]
-                    )
-            if len(non_exist_keys) > 0:
-                key_strs, buffer_ptrs, buffer_sizes = (
-                    self.mem_pool_host.get_buffer_meta(
-                        non_exist_keys, non_exist_indices
-                    )
-                )
-                # TODO: check the return value of batch set to see how many tokens are set successfully
-                self.storage_backend.batch_set(
-                    key_strs,
-                    target_location=buffer_ptrs,
-                    target_sizes=buffer_sizes,
+                logger.warning(
+                    f"Write page to storage: {len(batch_hashes)} pages failed."
                 )
-            operation.completed_tokens += len(operation.hash_value) * self.page_size
+                break
+            operation.completed_tokens += self.page_size * len(batch_hashes)
 
     def backup_thread_func(self):
         """
@@ -820,36 +881,8 @@ class HiCacheController:
                     continue
 
                 if not self.backup_skip:
-                    if self.is_mooncake_backend():
-                        self.mooncake_page_backup(operation)
-                    elif self.storage_backend_type == "hf3fs":
-                        if self.mem_pool_host.layout == "page_first":
-                            self.zerocopy_page_backup(operation, batch_size=128)
-                        elif self.mem_pool_host.layout == "layer_first":
-                            self.generic_page_backup(operation, batch_size=128)
-                    else:
-                        self.generic_page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                    self._page_backup(operation)
+                    self.ack_backup_queue.put(operation.id)
 
             except Empty:
                 continue