sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
The hunks below are from `sglang/srt/managers/cache_controller.py` (+396 -365). `…` marks removed lines (or short runs of lines) whose bodies were truncated in the rendered diff.

```diff
@@ -18,53 +18,78 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 
+from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
+
 if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
-from sglang.srt.distributed import get_tensor_model_parallel_rank
-
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 
 logger = logging.getLogger(__name__)
 
 
+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.…
-        self.…
-        self.…
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1
 
     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index
 
-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index
 
-    def …
-        …
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)
 
     def reset(self):
-        …
+        self.producer_index = -1
+        self.consumer_index = -1
 
 
 class CacheOperation:
```
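The new `LayerLoadingEvent`/`LayerDoneCounter` pair replaces the old condition-variable counters with per-layer `torch.cuda.Event`s: the producer records an event as each layer's copy is enqueued, and `wait_until(i)` now makes the consumer's CUDA stream wait on layer `i`'s event instead of blocking a Python thread. A minimal usage sketch (the scheduler-side wiring that normally drives this is assumed, not shown in this diff):

```python
import torch

# Sketch only: exercises the event-based counter from the hunk above.
counter = LayerDoneCounter(num_layers=2)
load_stream = torch.cuda.Stream()

producer_id = counter.update_producer()  # rotates through the 3 event sets
event = counter.events[producer_id]
event.start_event.record()               # the copies below depend on this point

with torch.cuda.stream(load_stream):
    event.start_event.wait(load_stream)
    for layer in range(2):
        # ... enqueue the host-to-device copy for `layer` on load_stream ...
        event.complete(layer)            # records load_events[layer] here

# Consumer (e.g., attention running on the default stream):
counter.set_consumer(producer_id)
counter.wait_until(0)  # default stream waits only for layer 0's copy
counter.wait_until(1)  # later layers can be consumed as they arrive
```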
```diff
@@ -88,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id
 
-    …
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority
 
-        return split_ops
 
-
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
 
 
 class TransferBuffer:
```
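`CacheOperation.merge`/`split` give way to a single `merge_ops`: queued operations are fused by concatenating their index tensors, keeping every node id and the minimum priority, so one stream transfer can acknowledge many radix-tree nodes at once. The new `HiCacheAck` tuple is what now flows back to the scheduler: start/finish CUDA events to poll, plus the node ids they cover, replacing the old per-node-id `Queue.put`. A small illustration with placeholder tensors (assuming, as the hunk implies, that the constructor records `node_id` into `node_ids`):

```python
import torch

a = CacheOperation(torch.tensor([0, 1]), torch.tensor([10, 11]), 1)
b = CacheOperation(torch.tensor([2, 3]), torch.tensor([12, 13]), 2)

merged = CacheOperation.merge_ops([a, b])
print(merged.host_indices.tolist())    # [0, 1, 2, 3]
print(merged.device_indices.tolist())  # [10, 11, 12, 13]
print(merged.node_ids)                 # node ids from both ops, in order
print(merged.priority == min(a.priority, b.priority))  # True
```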
```diff
@@ -196,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id
 
-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()
 
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True
 
-    def …
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True
 
-    def …
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag
 
 
 class HiCacheController:
@@ -226,11 +244,13 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event = None,
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
         prefetch_threshold: int = 256,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
     ):
         self.mem_pool_device_allocator = token_to_kv_pool_allocator
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -238,30 +258,37 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False
-
-        # todo: move backend initialization to storage backend module
+
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
-            from sglang.srt.mem_cache.hicache_storage import …
+            from sglang.srt.mem_cache.hicache_storage import get_hash_str
+
+            self.get_hash_str = get_hash_str
+            self.storage_config = self._generate_storage_config(
+                model_name, storage_backend_extra_config
+            )
+            # for MLA models, only one rank needs to backup the KV cache
+            self.backup_skip = (
+                self.storage_config.is_mla_model
+                # todo: load balancing
+                and self.storage_config.tp_rank != 0
+            )
 
             if storage_backend == "file":
-                …
+                from sglang.srt.mem_cache.hicache_storage import HiCacheFile
+
+                self.storage_backend = HiCacheFile(self.storage_config)
             elif storage_backend == "nixl":
                 from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
                 self.storage_backend = HiCacheNixl()
-                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
                 from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                     MooncakeStore,
-                    get_hash_str_mooncake,
                 )
 
-                self.storage_backend = MooncakeStore(…)
-                self.get_hash_str = get_hash_str_mooncake
+                self.storage_backend = MooncakeStore(self.storage_config)
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
                 assert self.mem_pool_host.layout == "page_first"
             elif storage_backend == "hf3fs":
@@ -279,19 +306,21 @@ class HiCacheController:
                 )
                 dtype = mem_pool_host.dtype
                 self.storage_backend = HiCacheHF3FS.from_env_config(
-                    bytes_per_page, dtype
+                    bytes_per_page, dtype, self.storage_config
                 )
-                self.get_hash_str = get_hash_str
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
 
@@ -302,15 +331,26 @@ class HiCacheController:
                 self.prefetch_tp_group = torch.distributed.new_group(
                     group_ranks, backend="gloo"
                 )
-                self.prefetch_io_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
 
-        …
+            # Select the get and set functions
+            self.page_get_func = self._generic_page_get
+            self.page_set_func = self._generic_page_set
+            self.batch_exists_func = self.storage_backend.batch_exists
+            self.is_3fs_zerocopy = (
+                self.storage_backend_type == "hf3fs"
+                and self.mem_pool_host.layout == "page_first"
+            )
+            if self.storage_backend_type == "mooncake":
+                self.page_get_func = self._mooncake_page_get
+                self.page_set_func = self._mooncake_page_set
+            elif self.is_3fs_zerocopy:
+                self.page_get_func = self._3fs_zero_copy_page_get
+                self.page_set_func = self._3fs_zero_copy_page_set
+                self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
 
         if write_policy not in [
@@ -320,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")
 
-        self.write_queue = PriorityQueue()
-        self.load_queue = …
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []
 
         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -335,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -357,21 +387,57 @@ class HiCacheController:
 
         self.prefetch_revoke_queue = Queue()
         self.ack_backup_queue = Queue()
+        self.host_mem_release_queue = Queue()
 
         self.prefetch_thread.start()
         self.backup_thread.start()
 
+    def _generate_storage_config(
+        self,
+        model_name: Optional[str] = None,
+        storage_backend_extra_config: Optional[str] = None,
+    ):
+
+        if is_dp_attention_enabled():
+            self.tp_rank = get_attention_tp_rank()
+            self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
+        else:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0
+
+        # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
+        is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
+
+        # Parse extra config JSON if provided
+        extra_config = None
+        if storage_backend_extra_config:
+            try:
+                import json
+
+                extra_config = json.loads(storage_backend_extra_config)
+            except Exception as e:
+                logger.error(f"Invalid backend extra config JSON: {e}")
+
+        return HiCacheStorageConfig(
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
+            model_name=model_name,
+            extra_config=extra_config,
+        )
+
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()
 
-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
```
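`_generate_storage_config` centralizes what each storage backend needs to know about the process: attention-TP rank/size and DP rank when DP attention is enabled, plain tensor-model-parallel rank/size otherwise, whether the KV pool is MLA (so only rank 0 backs up), the host layout, and an optional free-form JSON blob. A sketch of building the config by hand; the JSON key here is a hypothetical, backend-specific option (the controller only requires that the string parse as JSON):

```python
import json

extra = '{"master_address": "127.0.0.1:50051"}'  # hypothetical backend option

config = HiCacheStorageConfig(
    tp_rank=0,
    tp_size=1,
    is_mla_model=False,
    is_page_first_layout=True,
    model_name="example/model",      # optional, passed through to the backend
    extra_config=json.loads(extra),
)
```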
```diff
@@ -380,15 +446,7 @@ class HiCacheController:
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()
 
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -400,20 +458,11 @@ class HiCacheController:
             self.prefetch_thread.start()
             self.backup_thread.start()
 
-    @property
-    def backup_skip(self):
-        return (
-            self.is_mla
-            and get_tensor_model_parallel_rank() != 0
-            # todo: only support file and mooncake
-            and self.storage_backend_type in ["file", "mooncake"]
-        )
-
     def write(
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -422,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices
 
+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
```
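`write()` no longer hands work to a background writer thread: it drains the queue immediately onto `write_stream`, brackets the copy with start/finish CUDA events, and appends a `HiCacheAck` for the scheduler to poll. The `record_stream` calls implement a standard PyTorch lifetime rule: a tensor used by an asynchronous copy on a side stream must be registered with that stream, or the caching allocator may recycle its memory once the last Python reference dies. The pattern in isolation (a generic sketch, not sglang code):

```python
import torch

side = torch.cuda.Stream()

def enqueue_copy(dst: torch.Tensor) -> None:
    src = torch.randn_like(dst)  # temporary produced on the default stream
    ready = torch.cuda.Event()
    ready.record()               # marks the point the copy depends on
    with torch.cuda.stream(side):
        ready.wait(side)         # same role as start_event in the hunk
        dst.copy_(src, non_blocking=True)
        # Without this, src's memory could be reused by a later allocation
        # while the copy on `side` is still in flight:
        src.record_stream(side)

out = torch.empty(1 << 20, device="cuda")
enqueue_copy(out)
torch.cuda.synchronize()
```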
```diff
@@ -441,17 +519,18 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices
 
-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            …
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -459,58 +538,20 @@ class HiCacheController:
         else:
             raise ValueError(f"Unsupported io backend")
 
-    def …
-        …
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
 
-    def …
-        …
-                self.load_cache_event.wait(timeout=1)
-                if not self.load_cache_event.is_set():
-                    continue
-                self.load_cache_event.clear()
-                self.layer_done_counter.update_producer()
-
-                batch_operation = None
-                while self.load_queue.qsize() > 0:
-                    op = self.load_queue.get(block=True)
-                    if batch_operation is None:
-                        batch_operation = op
-                    else:
-                        batch_operation.merge(op)
-                if batch_operation is None:
-                    continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
 
-        …
-                    batch_operation.host_indices, batch_operation.device_indices
-                )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -518,13 +559,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-                …
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id
 
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
```
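`start_loading()` is the synchronous replacement for the old layer-by-layer load thread: it merges the queued loads, records the start event, enqueues one copy per layer on `load_stream` with `producer_event.complete(i)` after each, and returns the producer index (`-1` if nothing was queued). Compute can then overlap the remaining copies, since `register_layer_transfer_counter` lets the model wait per layer. A scheduler-side sketch (the wiring outside this diff is assumed):

```python
# Kick off any queued host-to-device loads before the forward pass.
producer_id = controller.start_loading()
if producer_id >= 0:
    controller.layer_done_counter.set_consumer(producer_id)

# During the forward pass, the KV pool calls (via the registered counter):
#   layer_done_counter.wait_until(i)
# so the compute stream blocks only until layer i's copy event, while the
# copies for layers i+1.. continue on load_stream in parallel.
```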
```diff
@@ -567,63 +619,93 @@ class HiCacheController:
         return operation
 
     def terminate_prefetch(self, operation):
-        operation.…
+        operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value
 
-    def …
-        …
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
+    def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
         )
-        …
-                    break
-                completed_tokens = operation.completed_tokens
-                if operation.increment(self.page_size * len(page_hashes)):
-                    for i in range(len(page_hashes)):
-                        completed_tokens += self.page_size
-                else:
-                    break
+        page_data = self.storage_backend.batch_get(hashes, dsts)
+        if page_data:
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
+        else:
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
+            )
 
-    def …
-        …
+    def _mooncake_page_get(self, operation, hash_values, host_indices):
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
+        )
+        get_result = self.storage_backend.batch_get(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        if get_result != len(hash_values):
+            logger.warning(
+                f"Prefetch operation {operation.request_id} failed or partially failed."
+            )
+        if get_result != 0:
+            operation.increment(get_result * self.page_size)
+
+    def _generic_page_get(self, operation, hash_values, host_indices):
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
+        page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
+        if page_data is None:
+            return
+        for i in range(len(hash_values)):
+            if page_data[i] is None:
                 logger.warning(
-                    f"Prefetch operation {operation.request_id} failed to retrieve page {…
+                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-        …
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
+
+    def _page_transfer(self, operation):
+        # Transfer batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
+            ]
+            prev_completed_tokens = operation.completed_tokens
+            # Get one batch token, and update the completed_tokens if succeed
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
+            # Check termination
+            if (
+                operation.completed_tokens
+                != prev_completed_tokens + len(batch_hashes) * self.page_size
+            ):
+                operation.mark_terminate()
+                break  # Some operations fail or operation terminated by controller
+        # release pre-allocated memory
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
         )
-        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
-        operation.increment(len(operation.hash_value) * self.page_size)
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
 
     def prefetch_io_aux_func(self):
         """
```
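Host memory also stops being freed inline. `append_host_mem_release` chops the indices into page-sized chunks and queues them on `host_mem_release_queue`; keeping chunks page-aligned means each entry can be returned to the host pool independently (the draining side lives outside this hunk). The split behavior, assuming a page size of 4 slots:

```python
import torch
from queue import Queue

page_size = 4
release_queue: Queue = Queue()

host_indices = torch.arange(10)
for chunk in host_indices.split(page_size):
    release_queue.put(chunk)

# The queue now holds [0..3], [4..7] and the short tail [8, 9]:
# torch.Tensor.split keeps the remainder as a final smaller chunk.
```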
```diff
@@ -632,35 +714,50 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if self.is_mooncake_backend():
-                    self.mooncake_page_transfer(operation)
-                elif self.storage_backend_type == "hf3fs":
-                    if self.mem_pool_host.layout == "page_first":
-                        self.zerocopy_page_transfer(operation, batch_size=128)
-                    elif self.mem_pool_host.layout == "layer_first":
-                        self.generic_page_transfer(operation, batch_size=128)
-                else:
-                    self.generic_page_transfer(operation)
-
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
+                self._page_transfer(operation)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue
 
-    def prefetch_rate_limit_check(self):
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False
+
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
+        last_hash = operation.last_hash
+        tokens_to_fetch = operation.token_ids
+
+        storage_query_count = 0
+        hash_value = []
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
+            )
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.batch_exists_func(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count
 
     def prefetch_thread_func(self):
         """
```
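`_storage_hit_query` batches what used to be a page-at-a-time `exists()` loop (with a mooncake special case) into one `batch_exists` call per `storage_batch_size` pages. Page keys are chained: each page's hash is derived from its tokens plus the previous page's hash, so hits are necessarily a prefix and the scan can stop at the first miss. A toy version of the chaining with a stand-in hash (sglang's real `get_hash_str` is not shown in this diff):

```python
import hashlib
from typing import List, Optional

def get_hash_str(tokens: List[int], prior: Optional[str]) -> str:
    # Stand-in: mixes the previous page's hash into the new page's key.
    h = hashlib.sha256()
    if prior is not None:
        h.update(prior.encode())
    h.update(b"".join(t.to_bytes(4, "little") for t in tokens))
    return h.hexdigest()

page_size = 2
tokens = [11, 12, 13, 14, 15, 16]
last, keys = None, []
for i in range(0, len(tokens), page_size):
    last = get_hash_str(tokens[i : i + page_size], last)
    keys.append(last)

# keys[k] commits to pages 0..k: change any earlier token and every later
# key changes too, so "stop at the first miss" loses no reachable hits.
```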
```diff
@@ -675,39 +772,7 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                storage_hit_count = 0
-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    last_hash = operation.last_hash
-                    tokens_to_fetch = operation.token_ids
-
-                    remaining_tokens = len(tokens_to_fetch)
-                    hash_value = []
-                    while remaining_tokens >= self.page_size:
-                        last_hash = self.get_hash_str(
-                            tokens_to_fetch[
-                                storage_hit_count : storage_hit_count + self.page_size
-                            ],
-                            last_hash,
-                        )
-
-                        # todo, more unified interface
-                        if not self.is_mooncake_backend():
-                            if not self.storage_backend.exists(last_hash):
-                                break
-                        hash_value.append(last_hash)
-                        storage_hit_count += self.page_size
-                        remaining_tokens -= self.page_size
-
-                    if self.is_mooncake_backend():
-                        # deferring to batch exists for mooncake store
-                        exist_result = self.storage_backend.exists(hash_value)
-                        storage_hit_count = (
-                            sum(1 for v in exist_result.values() if v != 0)
-                            * self.page_size
-                        )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -722,8 +787,7 @@ class HiCacheController:
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                    …
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -732,7 +796,9 @@ class HiCacheController:
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.…
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -755,59 +821,52 @@ class HiCacheController:
             self.backup_queue.put(operation)
         return operation.id
 
-    def …
-        …
+    # non-zero copy
+    def _generic_page_set(self, hash_values, host_indices) -> bool:
+        data = [
+            self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
+            for i in range(len(hash_values))
+        ]
+        return self.storage_backend.batch_set(hash_values, data)
+
+    # zero copy
+    def _mooncake_page_set(self, hash_values, host_indices) -> bool:
+        key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
+            hash_values,
+            host_indices,
+            self.storage_config.tp_rank,
         )
-        …
+        success = self.storage_backend.batch_set(
+            key_strs,
+            target_locations=buffer_ptrs,
+            target_sizes=buffer_sizes,
+        )
+        return success
+
+    # zero copy
+    def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
+            hash_values, host_indices
+        )
+        return self.storage_backend.batch_set(hashes, dsts)
+
+    # Backup batch by batch
+    def _page_backup(self, operation):
+        # Backup batch by batch
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
+            batch_host_indices = operation.host_indices[
+                i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
-            …
+            # Set one batch token, and record if success.
+            # todo: allow partial success
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
-                logger.warning(…
-                operation.completed_tokens += self.page_size * len(page_hashes)
-
-    def mooncake_page_backup(self, operation):
-        if len(operation.hash_value):
-            exist_hashvalues = self.storage_backend.exists(operation.hash_value)
-            indices = operation.host_indices.tolist()
-            non_exist_keys = []
-            non_exist_indices = []
-            for i in range(len(operation.hash_value)):
-                if not exist_hashvalues[operation.hash_value[i]]:
-                    non_exist_keys.append(operation.hash_value[i])
-                    non_exist_indices.extend(
-                        indices[i * self.page_size : (i + 1) * self.page_size]
-                    )
-            if len(non_exist_keys) > 0:
-                key_strs, buffer_ptrs, buffer_sizes = (
-                    self.mem_pool_host.get_buffer_meta(
-                        non_exist_keys, non_exist_indices
-                    )
-                )
-                # TODO: check the return value of batch set to see how many tokens are set successfully
-                self.storage_backend.batch_set(
-                    key_strs,
-                    target_location=buffer_ptrs,
-                    target_sizes=buffer_sizes,
+                logger.warning(
+                    f"Write page to storage: {len(batch_hashes)} pages failed."
                 )
-            …
+                break
+            operation.completed_tokens += self.page_size * len(batch_hashes)
 
     def backup_thread_func(self):
         """
```
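Backup is now symmetric with prefetch: one `_page_backup` loop over `storage_batch_size`-page batches, dispatching to a zero-copy setter (mooncake, or page-first hf3fs) or the generic flat-page copy. The index arithmetic that pairs a batch of hashes with its host slots, with toy numbers:

```python
page_size = 4
storage_batch_size = 2
hash_value = ["h0", "h1", "h2"]            # 3 pages -> 12 tokens
host_indices = list(range(12))

for i in range(0, len(hash_value), storage_batch_size):
    batch_hashes = hash_value[i : i + storage_batch_size]
    batch_host_indices = host_indices[
        i * page_size : (i + len(batch_hashes)) * page_size
    ]
    print(batch_hashes, len(batch_host_indices))
# ['h0', 'h1'] 8, then ['h2'] 4: the short final batch stays page-aligned.
```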
```diff
@@ -820,36 +879,8 @@ class HiCacheController:
                     continue
 
                 if not self.backup_skip:
-                    if self.is_mooncake_backend():
-                        self.mooncake_page_backup(operation)
-                    elif self.storage_backend_type == "hf3fs":
-                        if self.mem_pool_host.layout == "page_first":
-                            self.zerocopy_page_backup(operation, batch_size=128)
-                        elif self.mem_pool_host.layout == "layer_first":
-                            self.generic_page_backup(operation, batch_size=128)
-                    else:
-                        self.generic_page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                    self._page_backup(operation)
+                self.ack_backup_queue.put(operation)
 
             except Empty:
                 continue
```