sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py (+200 -199):

--- a/sglang/srt/managers/cache_controller.py
+++ b/sglang/srt/managers/cache_controller.py
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
     get_attention_tp_rank,
     get_attention_tp_size,
     is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)
 
 
+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1
 
     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index
 
-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index
 
-    def increment(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] += 1
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)
 
     def reset(self):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = 0
+        self.producer_index = -1
+        self.consumer_index = -1
 
 
 class CacheOperation:
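The rewrite above replaces the condition-variable counters with one `torch.cuda.Event` per layer: the load stream records an event as each layer lands, and consumers order their own stream after it instead of blocking the host. A minimal sketch of that record/wait pattern, assuming a CUDA device is available (everything except the `torch` APIs is illustrative):

```python
import torch

# One event per layer, recorded on a side stream and waited on from the
# compute stream: the pattern LayerLoadingEvent is built around.
num_layers = 4
load_stream = torch.cuda.Stream()
layer_events = [torch.cuda.Event() for _ in range(num_layers)]

kv = torch.zeros(num_layers, 1024, device="cuda")

with torch.cuda.stream(load_stream):
    for i in range(num_layers):
        kv[i] += 1.0               # stand-in for a per-layer host-to-device copy
        layer_events[i].record()   # mark layer i ready on the load stream

# Consumer: order the current stream after layer 2 only; later layers may
# still be in flight, and the host is never blocked.
torch.cuda.current_stream().wait_event(layer_events[2])
out = kv[2] * 2.0
```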
@@ -98,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id
 
-    def merge(self, other: "CacheOperation") -> None:
-        # multiple operations can be merged into a single operation for batch processing
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority
 
-        return split_ops
 
-    def __lt__(self, other: "CacheOperation"):
-        return self.priority < other.priority
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
 
 
 class TransferBuffer:
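`merge_ops` collapses a list of pending operations into one batched transfer: indices are concatenated, node ids accumulated, and the smallest (most urgent) priority wins. A hedged usage sketch, assuming the constructor signature visible in the hunk above and a working sglang install:

```python
import torch

from sglang.srt.managers.cache_controller import CacheOperation

ops = [
    CacheOperation(torch.arange(4), torch.arange(4), node_id=1),
    CacheOperation(torch.arange(4, 8), torch.arange(4, 8), node_id=2),
]

merged = CacheOperation.merge_ops(ops)
assert len(merged.host_indices) == 8  # one batched copy instead of two
assert merged.node_ids == [1, 2]      # acks still reach every waiting node
```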
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id
 
-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()
 
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True
 
-    def mark_done(self):
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True
 
-    def is_done(self) -> bool:
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag
 
 
 class HiCacheController:
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -324,8 +332,25 @@ class HiCacheController:
             group_ranks, backend="gloo"
         )
 
-        self.device = self.mem_pool_device.device
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
 
         if write_policy not in [
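Backend dispatch now happens once in the constructor: `page_get_func`, `page_set_func`, and `batch_exists_func` are bound to backend-specific methods up front, so the transfer loops below no longer re-select a function per batch. The same select-once pattern in isolation (the backend names mirror the diff; the class itself is made up):

```python
class PageStore:
    def __init__(self, backend_type: str, zero_copy: bool):
        # Bind the strategy once; hot paths call self.get_func directly.
        self.get_func = self._generic_get
        if backend_type == "mooncake":
            self.get_func = self._mooncake_get
        elif backend_type == "hf3fs" and zero_copy:
            self.get_func = self._zero_copy_get

    def _generic_get(self, key: str) -> str:
        return f"generic:{key}"

    def _mooncake_get(self, key: str) -> str:
        return f"mooncake:{key}"

    def _zero_copy_get(self, key: str) -> str:
        return f"zero-copy:{key}"


assert PageStore("mooncake", False).get_func("k") == "mooncake:k"
```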
@@ -335,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")
 
-        self.write_queue = PriorityQueue()
-        self.load_queue = Queue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []
 
         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -350,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -386,9 +401,11 @@ class HiCacheController:
         if is_dp_attention_enabled():
             self.tp_rank = get_attention_tp_rank()
             self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
         else:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0
 
         # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
         is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -407,21 +424,20 @@ class HiCacheController:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
             model_name=model_name,
             extra_config=extra_config,
         )
 
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()
 
-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -430,15 +446,7 @@ class HiCacheController:
         self.prefetch_revoke_queue.queue.clear()
         self.ack_backup_queue.queue.clear()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()
 
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -454,7 +462,7 @@ class HiCacheController:
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -463,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-        torch.cuda.current_stream().synchronize()
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices
 
+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+        # NOTE: We must save the host indices and device indices here,
+        # this is because we need to guarantee that these tensors are
+        # still alive when the write stream is executing.
+        if host_indices.is_cuda:
+            host_indices.record_stream(self.write_stream)
+        if device_indices.is_cuda:
+            device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
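The NOTE in `start_writing` is an allocator-safety point: the method returns while the copy is still queued on `write_stream`, and without `Tensor.record_stream` the caching allocator could recycle the index tensors' memory before the stream drains. A minimal sketch of the hazard and the fix (sizes and names are illustrative):

```python
import torch

side_stream = torch.cuda.Stream()

def launch_async_copy(dst: torch.Tensor) -> None:
    src = torch.randn_like(dst)
    with torch.cuda.stream(side_stream):
        dst.copy_(src, non_blocking=True)
    # 'src' goes out of scope on return, but the copy may not have run yet.
    # record_stream defers reuse of its memory until side_stream catches up.
    src.record_stream(side_stream)

out = torch.empty(1 << 20, device="cuda")
launch_async_copy(out)
torch.cuda.synchronize()
```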
@@ -482,76 +519,42 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices
 
-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
-            device_indices = device_indices.cpu()
-            host_indices, idx = host_indices.sort()
-            return host_indices, device_indices.index_select(0, idx)
+            if self.mem_pool_host.layout == "layer_first":
+                device_indices = device_indices.cpu()
+                host_indices, idx = host_indices.sort()
+                return host_indices, device_indices.index_select(0, idx)
+            elif self.mem_pool_host.layout == "page_first_direct":
+                return host_indices, device_indices.cpu()
        else:
            raise ValueError(f"Unsupported io backend")
 
-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
 
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
 
-            # start layer-wise KV cache transfer from CPU to GPU
-            self.layer_done_counter.reset()
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -559,13 +562,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment()
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+        # NOTE: We must save the host indices and device indices here,
+        # this is because we need to guarantee that these tensors are
+        # still alive when the load stream is executing.
+        if host_indices.is_cuda:
+            host_indices.record_stream(self.load_stream)
+        if device_indices.is_cuda:
+            device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id
 
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
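With the dedicated loader thread removed, loading is driven explicitly: `start_loading()` merges the queue, launches per-layer copies on `load_stream`, and returns a producer slot id, which the consumer binds to before waiting layer by layer. A hedged sketch of the resulting call sequence (the driver function is hypothetical; `controller` is a configured `HiCacheController`):

```python
import torch

def load_layer_by_layer(controller, host_indices: torch.Tensor) -> None:
    """Hypothetical driver showing the explicit load flow."""
    controller.load(host_indices, node_id=7)  # enqueue only; nothing copies yet
    producer_id = controller.start_loading()  # merge queue, launch on load_stream
    if producer_id < 0:
        return
    counter = controller.layer_done_counter
    counter.set_consumer(producer_id)
    for layer_id in range(controller.layer_num):
        # Enqueues a wait on the caller's current stream; the host never blocks.
        counter.wait_until(layer_id)
        # ... attention for layer_id may now read the freshly loaded KV pages ...
```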
@@ -608,7 +622,7 @@ class HiCacheController:
         return operation
 
     def terminate_prefetch(self, operation):
-        operation.mark_done()
+        operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value
 
     def append_host_mem_release(self, host_indices: torch.Tensor):
@@ -616,13 +630,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)
 
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -636,7 +656,7 @@ class HiCacheController:
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         if get_result != len(hash_values):
@@ -647,9 +667,9 @@ class HiCacheController:
         operation.increment(get_result * self.page_size)
 
     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -659,26 +679,16 @@ class HiCacheController:
                     f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            if operation.increment(self.page_size):
-                self.mem_pool_host.set_from_flat_data_page(
-                    host_indices[i * self.page_size],
-                    page_data[i],
-                )
-            else:
-                break  # Operation terminated by controller
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
 
     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.storage_backend_type == "mooncake":
-            get_func = self._mooncake_page_get
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-        else:
-            get_func = self._generic_page_get
-
         # Transfer batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
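The reordering in `_generic_page_get` closes a publish-before-signal race: `completed_tokens` is the watermark other components poll, so a page's bytes must be in host memory before the counter moves past it. The same discipline in miniature (names are hypothetical; a plain lock stands in for the operation's internal one):

```python
import threading

pages: dict[int, bytes] = {}
completed_pages = 0
lock = threading.Lock()

def publish(i: int, data: bytes) -> None:
    global completed_pages
    pages[i] = data           # 1) write the payload first...
    with lock:
        completed_pages += 1  # 2) ...then advance the watermark

def read_if_ready(i: int) -> bytes | None:
    with lock:
        done = completed_pages
    # Safe: any index below the watermark was fully published beforehand.
    return pages[i] if i < done else None
```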
@@ -687,12 +697,13 @@ class HiCacheController:
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
                 != prev_completed_tokens + len(batch_hashes) * self.page_size
             ):
+                operation.mark_terminate()
                 break  # Some operations fail or operation terminated by controller
             # release pre-allocated memory
             self.append_host_mem_release(
@@ -744,7 +755,7 @@ class HiCacheController:
                 batch_tokens[i : i + self.page_size], last_hash
             )
             batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -830,30 +841,20 @@ class HiCacheController:
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success
 
     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)
 
     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.storage_backend_type == "mooncake":
-            backup_set_func = self._mooncake_page_set
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-        else:
-            backup_set_func = self._generic_page_set
         # Backup batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +863,7 @@ class HiCacheController:
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."
@@ -882,7 +883,7 @@ class HiCacheController:
 
                 if not self.backup_skip:
                     self._page_backup(operation)
-                self.ack_backup_queue.put(operation.id)
+                self.ack_backup_queue.put(operation)
 
             except Empty:
                 continue