sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
The diff body below covers `sglang/srt/managers/cache_controller.py` (+242 -278), the HiCache controller rework in this release.

```diff
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
     get_attention_tp_rank,
     get_attention_tp_size,
     is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)
 
 
+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.counters = [num_layers] * self.num_counters
-        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
-        self.producer_index = 0
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1
 
     def update_producer(self):
-        self.producer_index = self.next_producer()
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index
 
-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index
 
-    def increment(self, layer_index: int):
-        with self.conditions[self.producer_index]:
-            self.counters[self.producer_index] = layer_index
-            self.conditions[self.producer_index].notify_all()
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)
 
     def reset(self):
-        self.producer_index = 0
-        self.consumer_index = 0
+        self.producer_index = -1
+        self.consumer_index = -1
 
 
 class CacheOperation:
```
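The hunk above replaces `LayerDoneCounter`'s condition-variable counters with per-layer `torch.cuda.Event`s, so a consumer blocks on the GPU timeline rather than on a Python lock. Below is a minimal, self-contained sketch of that handshake (not sglang code; assumes a CUDA device is available):

```python
import torch

def demo(num_layers: int = 4) -> None:
    assert torch.cuda.is_available(), "sketch requires a CUDA device"
    load_events = [torch.cuda.Event() for _ in range(num_layers)]
    load_stream = torch.cuda.Stream()
    src = [torch.randn(1024, device="cuda") for _ in range(num_layers)]
    dst = [torch.empty_like(t) for t in src]

    # Producer: copy layer by layer on a side stream, recording one event per layer.
    with torch.cuda.stream(load_stream):
        for i in range(num_layers):
            dst[i].copy_(src[i], non_blocking=True)
            load_events[i].record()  # "layer i is loaded"

    # Consumer: make the current stream wait on layer i's event before reading it.
    for i in range(num_layers):
        torch.cuda.current_stream().wait_event(load_events[i])
        _ = dst[i].sum()  # safe to read layer i while later layers still copy
    torch.cuda.synchronize()

demo()
```

Because `wait_event` enqueues the dependency on the consumer's stream, no Python thread ever sleeps; the ordering is enforced entirely by the GPU scheduler.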
```diff
@@ -98,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id
 
-    def merge(self, other: "CacheOperation") -> None:
-        self.host_indices = torch.cat([self.host_indices, other.host_indices])
-        self.device_indices = torch.cat([self.device_indices, other.device_indices])
-        self.priority = min(self.priority, other.priority)
-        self.node_ids.extend(other.node_ids)
-
-    def split(self, factor) -> List["CacheOperation"]:
-        # split an operation into smaller operations to reduce the size of intermediate buffers
-        if factor <= 1:
-            return [self]
-
-        chunk_size = math.ceil(len(self.host_indices) / factor)
-        split_ops = []
-        for i in range(0, len(self.host_indices), chunk_size):
-            split_ops.append(
-                CacheOperation(
-                    host_indices=self.host_indices[i : i + chunk_size],
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority
 
-        return split_ops
 
-
-
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
 
 
 class TransferBuffer:
```
```diff
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id
 
-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()
 
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
         with self._lock:
-            if self._done_flag:
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True
 
-    def mark_done(self):
+    def mark_terminate(self):
         with self._lock:
-            self._done_flag = True
+            self._terminated_flag = True
 
-    def is_done(self) -> bool:
-        return self._done_flag
+    def is_terminated(self) -> bool:
+        return self._terminated_flag
 
 
 class HiCacheController:
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -250,26 +258,21 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False
 
-        # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import get_hash_str
 
             self.get_hash_str = get_hash_str
-
             self.storage_config = self._generate_storage_config(
                 model_name, storage_backend_extra_config
             )
-            #
+            # for MLA models, only one rank needs to backup the KV cache
             self.backup_skip = (
                 self.storage_config.is_mla_model
-                # todo:
+                # todo: load balancing
                 and self.storage_config.tp_rank != 0
-                # todo: support other storage backends
-                and self.storage_backend_type in ["file", "mooncake"]
             )
 
             if storage_backend == "file":
@@ -309,12 +312,15 @@ class HiCacheController:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
 
@@ -325,15 +331,26 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
 
-        self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
-
+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
 
         if write_policy not in [
```
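The backend-specific get/set routines are now bound once in the constructor instead of being re-selected inside every transfer loop (compare the deleted selection blocks in `_page_transfer` and `_page_backup` further down). A sketch of this dispatch-table pattern with hypothetical backend names:

```python
class Controller:
    def __init__(self, backend_type: str, layout: str):
        # Bind once; the transfer loop stays backend-agnostic and branch-free.
        self.page_get_func = self._generic_page_get
        if backend_type == "mooncake":
            self.page_get_func = self._mooncake_page_get
        elif backend_type == "hf3fs" and layout == "page_first":
            self.page_get_func = self._zero_copy_page_get

    def _generic_page_get(self, hashes):  # copy through a staging buffer
        return f"generic get of {len(hashes)} pages"

    def _mooncake_page_get(self, hashes):  # transfer into registered buffers
        return f"mooncake get of {len(hashes)} pages"

    def _zero_copy_page_get(self, hashes):  # read directly into the host pool
        return f"zero-copy get of {len(hashes)} pages"

print(Controller("hf3fs", "page_first").page_get_func(["h0", "h1"]))
```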
```diff
@@ -343,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")
 
-        self.write_queue = PriorityQueue()
-        self.load_queue = PriorityQueue()
-
-        self.ack_write_queue = Queue()
-        self.ack_load_queue = Queue()
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []
 
         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -358,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -380,6 +387,7 @@ class HiCacheController:
 
             self.prefetch_revoke_queue = Queue()
             self.ack_backup_queue = Queue()
+            self.host_mem_release_queue = Queue()
 
             self.prefetch_thread.start()
             self.backup_thread.start()
@@ -393,9 +401,11 @@ class HiCacheController:
         if is_dp_attention_enabled():
             self.tp_rank = get_attention_tp_rank()
             self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
         else:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0
 
         # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
         is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -414,21 +424,20 @@ class HiCacheController:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
             model_name=model_name,
             extra_config=extra_config,
         )
 
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()
 
-        self.write_queue.queue.clear()
-        self.load_queue.queue.clear()
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.queue.clear()
-        self.ack_load_queue.queue.clear()
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -437,15 +446,7 @@ class HiCacheController:
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()
 
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -461,7 +462,7 @@ class HiCacheController:
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -470,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices
 
+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the write stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.write_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int = 0,
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
```
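`start_writing` runs the device-to-host copy asynchronously on `write_stream`, bracketed by start/finish events, and calls `Tensor.record_stream` so the index tensors stay alive while the side stream still needs them. A minimal sketch of that idiom (assumes a CUDA device):

```python
import torch

assert torch.cuda.is_available(), "sketch requires a CUDA device"
side_stream = torch.cuda.Stream()
start_event, finish_event = torch.cuda.Event(), torch.cuda.Event()

x = torch.randn(1 << 20, device="cuda")  # allocated on the default stream

start_event.record()  # capture everything queued on the default stream so far
with torch.cuda.stream(side_stream):
    start_event.wait(side_stream)  # side stream begins after the default stream
    y = x * 2.0                    # asynchronous consumption of x
    finish_event.record()
    # Without this, freeing x lets the caching allocator hand its memory to a
    # new tensor while the side stream may still be reading it:
    x.record_stream(side_stream)

del x                       # safe: reuse is deferred until side_stream passes here
finish_event.synchronize()  # host waits, so reading y below is ordered correctly
print(y.sum().item())
```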
```diff
@@ -489,17 +519,18 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices
 
-    def move_indices(self, host_indices, device_indices):
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-            return host_indices.to(self.mem_pool_device.device), device_indices
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
             device_indices = device_indices.cpu()
             host_indices, idx = host_indices.sort()
@@ -507,58 +538,20 @@ class HiCacheController:
         else:
             raise ValueError(f"Unsupported io backend")
 
-    def write_thread_func_direct(self):
-        """
-        Directly write through KV caches to host memory without buffering.
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
 
-    def load_thread_func_layer_by_layer(self):
-        """
-        Load KV caches from host memory to device memory layer by layer.
-        """
-        torch.cuda.set_stream(self.load_stream)
-        while not self.stop_event.is_set():
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
 
-            host_indices, device_indices = self.move_indices(
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -566,13 +559,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-                self.load_stream.synchronize()
-                self.layer_done_counter.increment(i)
-
-            self.mem_pool_host.complete_io(batch_operation.host_indices)
-            for node_id in batch_operation.node_ids:
-                if node_id != 0:
-                    self.ack_load_queue.put(node_id)
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+            # NOTE: We must save the host indices and device indices here,
+            # this is because we need to guarantee that these tensors are
+            # still alive when the load stream is executing.
+            if host_indices.is_cuda:
+                host_indices.record_stream(self.load_stream)
+            if device_indices.is_cuda:
+                device_indices.record_stream(self.load_stream)
 
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id
 
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
```
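Both `start_writing` and `start_loading` now hand back `HiCacheAck` entries instead of pushing node ids into blocking queues, so the consumer can reap finished transfers without blocking by polling `Event.query()`. A sketch of that polling pattern (assumes a CUDA device; the tuple stands in for `HiCacheAck`):

```python
import torch

assert torch.cuda.is_available(), "sketch requires a CUDA device"
stream = torch.cuda.Stream()
ack_queue = []  # (finish_event, node_ids) pairs, standing in for HiCacheAck

for node_id in range(3):  # pretend each iteration is one merged transfer
    buf = torch.empty(1 << 18, device="cuda")
    finish_event = torch.cuda.Event()
    with torch.cuda.stream(stream):
        buf.normal_()              # the "transfer" work on the side stream
        finish_event.record()
        buf.record_stream(stream)  # keep buf alive for the side stream (see above)
    ack_queue.append((finish_event, [node_id]))

acked = []
while ack_queue:
    finish_event, node_ids = ack_queue[0]
    if not finish_event.query():   # non-blocking readiness check
        continue                   # a real scheduler would do other work here
    ack_queue.pop(0)
    acked.extend(node_ids)         # notify the waiting cache-tree nodes
print(sorted(acked))               # [0, 1, 2]
```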
```diff
@@ -615,31 +619,41 @@ class HiCacheController:
         return operation
 
     def terminate_prefetch(self, operation):
-        operation.mark_done()
+        operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value
 
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
             )
 
-    # zero copy
     def _mooncake_page_get(self, operation, hash_values, host_indices):
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
             host_indices,
+            self.storage_config.tp_rank,
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         if get_result != len(hash_values):
```
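`append_host_mem_release` defers frees to the scheduler by queueing page-sized chunks of the host index tensor rather than freeing inline. A small sketch of the mechanism (hypothetical page size; the real value comes from the host memory pool):

```python
from queue import Queue
import torch

page_size = 4
host_mem_release_queue: Queue = Queue()

def append_host_mem_release(host_indices: torch.Tensor) -> None:
    # Split into page-sized chunks; the scheduler thread frees them one by one.
    for chunk in host_indices.split(page_size):
        host_mem_release_queue.put(chunk)

append_host_mem_release(torch.arange(10))
while not host_mem_release_queue.empty():
    print(host_mem_release_queue.get().tolist())  # [0..3], [4..7], [8, 9]
```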
```diff
@@ -649,12 +663,10 @@ class HiCacheController:
         if get_result != 0:
             operation.increment(get_result * self.page_size)
 
-    # non-zero copy
     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -664,49 +676,36 @@ class HiCacheController:
                     f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-            operation.increment(self.page_size)
-            self.mem_pool_host.set_from_flat_data_page(
-                host_indices[i * self.page_size],
-                page_data[i],
-            )
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller
 
     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.is_mooncake_backend():
-            get_func = self._mooncake_page_get
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                get_func = self._3fs_zero_copy_page_get
-            elif self.mem_pool_host.layout == "layer_first":
-                get_func = self._generic_page_get
-            batch_size = 128
-        else:
-            get_func = self._generic_page_get
-            batch_size = 8
-
         # Transfer batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
                 != prev_completed_tokens + len(batch_hashes) * self.page_size
             ):
+                operation.mark_terminate()
                 break  # Some operations fail or operation terminated by controller
         # release pre-allocated memory
-        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )
 
     def prefetch_io_aux_func(self):
         """
```
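The transfer loop above moves pages in fixed `storage_batch_size` batches and terminates the whole operation on the first short batch, releasing whatever was pre-allocated past the completed tokens. A distilled sketch of that accounting (toy backend, hypothetical sizes):

```python
from typing import Callable, List

def transfer(hashes: List[str], page_size: int, batch_size: int,
             get_batch: Callable[[List[str]], int]) -> int:
    """Fetch pages batch by batch; stop at the first incomplete batch."""
    completed_tokens = 0
    for i in range(0, len(hashes), batch_size):
        batch = hashes[i : i + batch_size]
        got = get_batch(batch)  # number of pages the backend actually returned
        completed_tokens += got * page_size
        if got != len(batch):   # short read: terminate and release the remainder
            break
    return completed_tokens

def flaky_backend(batch: List[str]) -> int:
    return sum(1 for h in batch if int(h) < 5)  # pages 5 and up are missing

print(transfer([str(i) for i in range(8)], page_size=64, batch_size=4,
               get_batch=flaky_backend))  # 320 = (4 + 1) pages * 64 tokens
```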
```diff
@@ -716,47 +715,49 @@ class HiCacheController:
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 self._page_transfer(operation)
-
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue
 
-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False
 
-    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
         last_hash = operation.last_hash
         tokens_to_fetch = operation.token_ids
 
         storage_query_count = 0
-        remaining_tokens = len(tokens_to_fetch)
         hash_value = []
-        while remaining_tokens >= self.page_size:
-            last_hash = self.get_hash_str(
-                tokens_to_fetch[
-                    storage_query_count : storage_query_count + self.page_size
-                ],
-                last_hash,
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
             )
-            if not self.storage_backend.exists(last_hash):
-                break
-            hash_value.append(last_hash)
-            storage_query_count += self.page_size
-            remaining_tokens -= self.page_size
-        return hash_value, storage_query_count
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.batch_exists_func(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count
 
     def prefetch_thread_func(self):
         """
```
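`_storage_hit_query` chains page hashes (each page's key depends on the previous page's hash) and now probes the backend batch by batch, keeping only the contiguous hit prefix. A toy sketch of the lookup (simplified hash chain, not sglang's `get_hash_str`):

```python
import hashlib
from typing import List, Optional, Set, Tuple

def hash_page(tokens: List[int], prev: Optional[str]) -> str:
    return hashlib.sha256((prev or "").encode() + str(tokens).encode()).hexdigest()

def storage_hit_query(tokens: List[int], page_size: int, batch_size: int,
                      exists: Set[str]) -> Tuple[List[str], int]:
    hash_value: List[str] = []
    count, last = 0, None
    for start in range(0, len(tokens), page_size * batch_size):
        batch = tokens[start : start + page_size * batch_size]
        batch_hashes = []
        for i in range(0, len(batch), page_size):
            last = hash_page(batch[i : i + page_size], last)
            batch_hashes.append(last)
        hit = 0  # length of the contiguous prefix present in storage
        for h in batch_hashes:
            if h not in exists:
                break
            hit += 1
        hash_value.extend(batch_hashes[:hit])
        count += hit * page_size
        if hit < len(batch_hashes):
            break  # first miss ends the usable prefix
    return hash_value, count

# Back up the first 3 of 4 pages, then query: only that prefix is reported as hit.
tokens, page_size = list(range(16)), 4
chain, last = [], None
for i in range(0, len(tokens), page_size):
    last = hash_page(tokens[i : i + page_size], last)
    chain.append(last)
hits, n = storage_hit_query(tokens, page_size, batch_size=2, exists=set(chain[:3]))
print(len(hits), n)  # 3 12
```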
```diff
@@ -771,13 +772,7 @@ class HiCacheController:
             if operation is None:
                 continue
 
-            if (
-                operation.host_indices is not None
-            ) and self.prefetch_rate_limit_check():
-                hash_value, storage_hit_count = self._generic_storage_hit_query(
-                    operation
-                )
-
+            hash_value, storage_hit_count = self._storage_hit_query(operation)
             if self.tp_world_size > 1:
                 storage_hit_count_tensor = torch.tensor(
                     storage_hit_count, dtype=torch.int
@@ -792,8 +787,7 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
-                if operation.host_indices is not None:
-                    self.mem_pool_host.free(operation.host_indices)
+                self.append_host_mem_release(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
@@ -802,7 +796,9 @@ class HiCacheController:
                     : (storage_hit_count // self.page_size)
                 ]
                 # free the pre-allocated memory for pages that are not hit
-                self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                self.append_host_mem_release(
+                    operation.host_indices[storage_hit_count:]
+                )
                 operation.host_indices = operation.host_indices[:storage_hit_count]
                 logger.debug(
                     f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -838,45 +834,33 @@ class HiCacheController:
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
             host_indices,
+            self.storage_config.tp_rank,
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success
 
     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
            hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)
 
     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.is_mooncake_backend():
-            backup_set_func = self._mooncake_page_set
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                backup_set_func = self._3fs_zero_copy_page_set
-            elif self.mem_pool_host.layout == "layer_first":
-                backup_set_func = self._generic_page_set
-            batch_size = 128
-        else:
-            backup_set_func = self._generic_page_set
-            batch_size = 8
         # Backup batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."
@@ -896,27 +880,7 @@ class HiCacheController:
 
             if not self.backup_skip:
                 self._page_backup(operation)
-                min_completed_tokens = operation.completed_tokens
-            else:
-                min_completed_tokens = len(operation.token_ids)
-
-            if self.tp_world_size > 1:
-                completed_tokens_tensor = torch.tensor(
-                    min_completed_tokens, dtype=torch.int
-                )
-                torch.distributed.all_reduce(
-                    completed_tokens_tensor,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.backup_tp_group,
-                )
-                min_completed_tokens = completed_tokens_tensor.item()
-
-            self.ack_backup_queue.put(
-                (
-                    operation.id,
-                    min_completed_tokens,
-                )
-            )
+            self.ack_backup_queue.put(operation)
 
         except Empty:
             continue
```