sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py

@@ -31,27 +31,13 @@ if not (_is_npu or _is_xpu):
 logger = logging.getLogger(__name__)


-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
-
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)

-    return _decorator
+    return wrapper


 class HostKVCache(abc.ABC):
@@ -110,7 +96,6 @@ class HostKVCache(abc.ABC):

         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()

     @abc.abstractmethod
@@ -140,7 +125,7 @@ class HostKVCache(abc.ABC):
         raise NotImplementedError()

     @abc.abstractmethod
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         """
         Get a flat data page from the host memory pool.
         """
@@ -161,7 +146,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()

-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -172,7 +157,7 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)

-    @synchronized()
+    @synchronized
     def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
@@ -183,92 +168,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]

-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index

-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)

-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-

 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
@@ -461,13 +367,19 @@ class MHATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -494,12 +406,22 @@ class MHATokenToKVPoolHost(HostKVCache):
                     2, self.page_size, self.layer_num, self.head_num, self.head_dim
                 )
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
+                data_page.reshape(
+                    2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
+                )
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """ "
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
         v_offset = (
@@ -509,48 +431,52 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.head_dim
             * self.dtype.itemsize
         )
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                    )
+                    v_ptr = k_ptr + v_offset
+                    ptr_list.append(k_ptr)
+                    ptr_list.append(v_ptr)
+            element_size = (
+                self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
                 * self.head_num
                 * self.head_dim
-                * self.dtype.itemsize
             )
-            v_ptr = k_ptr + v_offset
-            ptr_list.append(k_ptr)
-            ptr_list.append(v_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{local_rank}_k")
-            key_list.append(f"{key_}_{local_rank}_v")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * self.head_num
-            * self.head_dim
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        key_list = []
-        buf_list = []
-
-        for i in range(len(keys)):
-            key = keys[i]
-            key_list.append(f"{key}-k")
-            key_list.append(f"{key}-v")
-            if indices is not None:
-                index = indices[i * self.page_size]
-                buf_list.append(self.k_buffer[index : index + self.page_size])
-                buf_list.append(self.v_buffer[index : index + self.page_size])
-
-        return key_list, buf_list, 2
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list


 class MLATokenToKVPoolHost(HostKVCache):
@@ -726,13 +652,19 @@ class MLATokenToKVPoolHost(HostKVCache):
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")

-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page

     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -762,43 +694,63 @@ class MLATokenToKVPoolHost(HostKVCache):
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape(
+                1,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")

-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """ "
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                    )
+                    ptr_list.append(k_ptr)
+            element_size = (
+                self.dtype.itemsize
+                * self.page_size
                 * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+            element_size = (
+                self.layer_num
                 * self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
             )
-            ptr_list.append(k_ptr)
-            key_ = keys[index // self.page_size]
-            key_list.append(f"{key_}_{local_rank}_k")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * (self.kv_lora_rank + self.qk_rope_head_dim)
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        buf_list = []
-
-        if indices is not None:
-            for i in range(len(keys)):
-                index = indices[i * self.page_size]
-                buf_list.append(self.kv_buffer[index : index + self.page_size])
-
-        return keys, buf_list, 1
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list