sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool_host.py

@@ -3,22 +3,26 @@ import logging
 import threading
 from enum import IntEnum
 from functools import wraps
+from typing import Optional
 
 import psutil
 import torch
 
 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import is_npu
+from sglang.srt.utils import is_npu, is_xpu
 
 _is_npu = is_npu()
-if not _is_npu:
+_is_xpu = is_xpu()
+if not (_is_npu or _is_xpu):
     from sgl_kernel.kvcacheio import (
         transfer_kv_all_layer,
+        transfer_kv_all_layer_direct_lf_pf,
         transfer_kv_all_layer_lf_pf,
         transfer_kv_all_layer_mla,
         transfer_kv_all_layer_mla_lf_pf,
         transfer_kv_direct,
         transfer_kv_per_layer,
+        transfer_kv_per_layer_direct_pf_lf,
         transfer_kv_per_layer_mla,
         transfer_kv_per_layer_mla_pf_lf,
         transfer_kv_per_layer_pf_lf,
@@ -27,27 +31,13 @@ if not _is_npu:
 logger = logging.getLogger(__name__)
 
 
-class MemoryStateInt(IntEnum):
-    IDLE = 0
-    RESERVED = 1
-    PROTECTED = 2
-    SYNCED = 3
-    BACKUP = 4
+def synchronized(func):
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        with self.lock:
+            return func(self, *args, **kwargs)
 
-
-def synchronized(debug_only=False):
-    def _decorator(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            if (not debug_only) or self.debug:
-                with self.lock:
-                    return func(self, *args, **kwargs)
-            else:
-                return True
-
-        return wrapper
-
-    return _decorator
+    return wrapper
 
 
 class HostKVCache(abc.ABC):
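The hunk above replaces the parameterized decorator factory with a plain decorator, so the debug-only MemoryStateInt state tracking disappears and every decorated method simply serializes on the instance lock. A minimal illustrative sketch of the new usage pattern; the Counter class here is hypothetical and only mirrors the pattern HostKVCache uses with its RLock:

import threading
from functools import wraps


def synchronized(func):
    # New style: applied directly; always takes self.lock around the call.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)

    return wrapper


class Counter:
    # Hypothetical example class, not part of sglang.
    def __init__(self):
        self.lock = threading.RLock()
        self.value = 0

    @synchronized  # previously written as @synchronized() when the factory took a debug_only flag
    def increment(self):
        self.value += 1
        return self.value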
@@ -76,6 +66,7 @@ class HostKVCache(abc.ABC):
         self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
+        self.page_num = self.size // self.page_size
         self.start_layer = device_pool.start_layer
         self.end_layer = device_pool.end_layer
 
@@ -105,7 +96,6 @@ class HostKVCache(abc.ABC):
 
         # A lock for synchronized operations on memory allocation and state transitions.
         self.lock = threading.RLock()
-        self.debug = logger.isEnabledFor(logging.DEBUG)
         self.clear()
 
     @abc.abstractmethod
@@ -135,7 +125,7 @@ class HostKVCache(abc.ABC):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         """
         Get a flat data page from the host memory pool.
         """
@@ -156,7 +146,7 @@ class HostKVCache(abc.ABC):
         """
         raise NotImplementedError()
 
-    @synchronized()
+    @synchronized
     def clear(self):
         # Initialize memory states and tracking structures.
         self.mem_state = torch.zeros(
@@ -167,8 +157,8 @@ class HostKVCache(abc.ABC):
     def available_size(self):
         return len(self.free_slots)
 
-    @synchronized()
-    def alloc(self, need_size: int) -> torch.Tensor:
+    @synchronized
+    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
         assert (
             need_size % self.page_size == 0
         ), "The requested size should be a multiple of the page size."
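Because alloc() is now typed as returning Optional[torch.Tensor], callers are expected to handle a None result (presumably returned when the free list cannot satisfy the request) rather than assuming a tensor always comes back. A hedged caller-side sketch; the host_cache argument is assumed to be a HostKVCache-like object exposing page_size and the alloc() shown above:

from typing import Optional

import torch


def try_alloc_pages(host_cache, num_pages: int) -> Optional[torch.Tensor]:
    # Requests must be page-aligned, per the assert in alloc().
    need_size = num_pages * host_cache.page_size
    host_indices = host_cache.alloc(need_size)
    if host_indices is None:
        # Not enough free host slots; the caller can skip offloading this batch.
        return None
    return host_indices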
@@ -178,92 +168,13 @@ class HostKVCache(abc.ABC):
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        if self.debug:
-            self.mem_state[select_index] = MemoryStateInt.RESERVED
-
         return select_index
 
-    @synchronized()
+    @synchronized
     def free(self, indices: torch.Tensor) -> int:
         self.free_slots = torch.cat([self.free_slots, indices])
-        if self.debug:
-            self.mem_state[indices] = MemoryStateInt.IDLE
         return len(indices)
 
-    @synchronized(debug_only=True)
-    def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
-        assert len(indices) > 0, "The indices should not be empty"
-        states = self.mem_state[indices]
-        assert (
-            states == states[0]
-        ).all(), "The memory slots should have the same state {}".format(states)
-        return MemoryStateInt(states[0].item())
-
-    @synchronized(debug_only=True)
-    def is_reserved(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.RESERVED
-
-    @synchronized(debug_only=True)
-    def is_protected(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def is_synced(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def is_backup(self, indices: torch.Tensor) -> bool:
-        return self.get_state(indices) == MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_backup(self, indices: torch.Tensor):
-        if not self.is_synced(indices):
-            raise ValueError(
-                f"The host memory slots should be in SYNCED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_prefetch(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be in RESERVED state before turning into BACKUP. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.BACKUP
-
-    @synchronized(debug_only=True)
-    def update_synced(self, indices: torch.Tensor):
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
-    @synchronized(debug_only=True)
-    def protect_write(self, indices: torch.Tensor):
-        if not self.is_reserved(indices):
-            raise ValueError(
-                f"The host memory slots should be RESERVED before write operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def protect_load(self, indices: torch.Tensor):
-        if not self.is_backup(indices):
-            raise ValueError(
-                f"The host memory slots should be in BACKUP state before load operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.PROTECTED
-
-    @synchronized(debug_only=True)
-    def complete_io(self, indices: torch.Tensor):
-        if not self.is_protected(indices):
-            raise ValueError(
-                f"The host memory slots should be PROTECTED during I/O operations. "
-                f"Current state: {self.get_state(indices)}"
-            )
-        self.mem_state[indices] = MemoryStateInt.SYNCED
-
 
 class MHATokenToKVPoolHost(HostKVCache):
     device_pool: MHATokenToKVPool
@@ -315,6 +226,15 @@ class MHATokenToKVPoolHost(HostKVCache):
             dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
         elif self.layout == "page_first":
             dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
+        elif self.layout == "page_first_direct":
+            dims = (
+                2,
+                self.page_num,
+                self.layer_num,
+                self.page_size,
+                self.head_num,
+                self.head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
@@ -368,19 +288,31 @@ class MHATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
+                    dst_layers=[
+                        device_pool.k_buffer[layer_id],
+                        device_pool.v_buffer[layer_id],
+                    ],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_per_layer_direct_pf_lf(
+                    src_ptrs=[self.k_buffer, self.v_buffer],
+                    dst_ptrs=[
+                        device_pool.k_buffer[layer_id],
+                        device_pool.v_buffer[layer_id],
+                    ],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    layer_id=layer_id,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
@@ -414,26 +346,40 @@ class MHATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-
-
-
-
-
-
-
-
-
-
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=device_pool.k_buffer + device_pool.v_buffer,
+                    dst_layers=self.k_data_refs + self.v_data_refs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_all_layer_direct_lf_pf(
+                    src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
+                    dst_ptrs=[self.k_buffer, self.v_buffer],
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
 
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
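In the new page_first_direct layout the MHA host buffer is shaped (2, page_num, layer_num, page_size, head_num, head_dim), so get_data_page converts the token-level index into a page slot with index // page_size before slicing. A standalone sketch of that mapping with made-up sizes (not the real pool configuration):

import torch

# Assumed toy dimensions for illustration only.
page_size, page_num, layer_num, head_num, head_dim = 4, 8, 2, 2, 8

# page_first_direct MHA host layout from the diff above:
# (2, page_num, layer_num, page_size, head_num, head_dim)
kv_buffer = torch.zeros(2, page_num, layer_num, page_size, head_num, head_dim)

index = 12                       # page-aligned token index
real_index = index // page_size  # -> page slot 3
data_page = kv_buffer[:, real_index : real_index + 1, :, :, :, :]

# One page holds K and V for every layer, head, and token of that page.
assert data_page.numel() == 2 * layer_num * page_size * head_num * head_dim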
@@ -460,12 +406,22 @@ class MHATokenToKVPoolHost(HostKVCache):
                     2, self.page_size, self.layer_num, self.head_num, self.head_dim
                 )
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
+                data_page.reshape(
+                    2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
+                )
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """ "
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
         v_offset = (
@@ -475,48 +431,52 @@ class MHATokenToKVPoolHost(HostKVCache):
             * self.head_dim
             * self.dtype.itemsize
         )
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * self.head_num
+                        * self.head_dim
+                        * self.dtype.itemsize
+                    )
+                    v_ptr = k_ptr + v_offset
+                    ptr_list.append(k_ptr)
+                    ptr_list.append(v_ptr)
+            element_size = (
+                self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * self.head_num
+                    * self.head_dim
+                    * self.dtype.itemsize
+                )
+                v_ptr = k_ptr + v_offset
+                ptr_list.append(k_ptr)
+                ptr_list.append(v_ptr)
+            element_size = (
+                self.layer_num
+                * self.dtype.itemsize
+                * self.page_size
                 * self.head_num
                 * self.head_dim
-                * self.dtype.itemsize
             )
-
-
-
-
-            key_list.append(f"{key_}_{local_rank}_k")
-            key_list.append(f"{key_}_{local_rank}_v")
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * self.head_num
-            * self.head_dim
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        key_list = []
-        buf_list = []
-
-        for i in range(len(keys)):
-            key = keys[i]
-            key_list.append(f"{key}-k")
-            key_list.append(f"{key}-v")
-            if indices is not None:
-                index = indices[i * self.page_size]
-                buf_list.append(self.k_buffer[index : index + self.page_size])
-                buf_list.append(self.v_buffer[index : index + self.page_size])
-
-        return key_list, buf_list, 2
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list
 
 
 class MLATokenToKVPoolHost(HostKVCache):
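For the page-first layouts, get_page_buffer_meta derives one K and one V pointer per page: the K pointer is the buffer base plus token_index * layer_num * head_num * head_dim * itemsize, the V pointer adds a fixed v_offset that spans the whole K half of the buffer, and each pointer covers element_size = layer_num * page_size * head_num * head_dim * itemsize bytes. A worked sketch of the same arithmetic with toy sizes (the numbers are made up, not the real pool configuration):

import torch

# Toy sizes for illustration only, modeled on the page_first host layout
# (2, size, layer_num, head_num, head_dim) shown earlier in this diff.
size, page_size, layer_num, head_num, head_dim = 16, 4, 2, 2, 8
dtype = torch.bfloat16

kv_buffer = torch.zeros(2, size, layer_num, head_num, head_dim, dtype=dtype)
base = kv_buffer.data_ptr()

# Byte offset from the K half (leading index 0) to the V half (leading index 1).
v_offset = layer_num * size * head_num * head_dim * dtype.itemsize

token_index = 8  # first token of the third page
k_ptr = base + token_index * layer_num * head_num * head_dim * dtype.itemsize
v_ptr = k_ptr + v_offset

# Each pointer addresses one page of one K/V half across all layers.
element_size = layer_num * dtype.itemsize * page_size * head_num * head_dim
assert k_ptr == kv_buffer[0, token_index].data_ptr()
assert v_ptr == kv_buffer[1, token_index].data_ptr()
assert element_size == kv_buffer[0, token_index : token_index + page_size].numel() * dtype.itemsize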
@@ -578,6 +538,14 @@ class MLATokenToKVPoolHost(HostKVCache):
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            dims = (
+                self.page_num,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
         self.token_stride_size = (
@@ -617,16 +585,25 @@ class MLATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-
-
-
-
-
-
-
-
-
-
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=[self.kv_buffer[layer_id]],
+                    dst_layers=[device_pool.kv_buffer[layer_id]],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_per_layer_direct_pf_lf(
+                    src_ptrs=[self.kv_buffer],
+                    dst_ptrs=[device_pool.kv_buffer[layer_id]],
+                    src_indices=host_indices,
+                    dst_indices=device_indices,
+                    layer_id=layer_id,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
 
     def backup_from_device_all_layer(
         self, device_pool, host_indices, device_indices, io_backend
@@ -654,26 +631,40 @@ class MLATokenToKVPoolHost(HostKVCache):
             else:
                 raise ValueError(f"Unsupported layout: {self.layout}")
         elif io_backend == "direct":
-
-
-
-
-
-
-
-
-
-
+            if self.layout == "layer_first":
+                transfer_kv_direct(
+                    src_layers=device_pool.kv_buffer,
+                    dst_layers=self.data_refs,
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            elif self.layout == "page_first_direct":
+                transfer_kv_all_layer_direct_lf_pf(
+                    src_ptrs=device_pool.kv_buffer,
+                    dst_ptrs=[self.kv_buffer],
+                    src_indices=device_indices,
+                    dst_indices=host_indices,
+                    page_size=self.page_size,
+                )
+            else:
+                raise ValueError(f"Unsupported layout: {self.layout}")
         else:
             raise ValueError(f"Unsupported IO backend: {io_backend}")
 
-    def get_flat_data_page(self, index) -> torch.Tensor:
+    def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
         if self.layout == "layer_first":
-            return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
+            data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
         elif self.layout == "page_first":
-            return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
+            data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
+        if flat:
+            data_page = data_page.flatten()
+        return data_page
 
     def get_dummy_flat_data_page(self) -> torch.Tensor:
         return torch.zeros(
@@ -703,43 +694,63 @@ class MLATokenToKVPoolHost(HostKVCache):
                 1,
                 self.kv_lora_rank + self.qk_rope_head_dim,
             )
+        elif self.layout == "page_first_direct":
+            real_index = index // self.page_size
+            self.kv_buffer[real_index : real_index + 1, :, :, :, :] = data_page.reshape(
+                1,
+                self.layer_num,
+                self.page_size,
+                1,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+            )
         else:
             raise ValueError(f"Unsupported layout: {self.layout}")
 
-    def get_buffer_meta(self, keys, indices, local_rank):
+    def get_page_buffer_meta(self, indices):
+        """ "
+        meta data for zero copy
+        """
+        assert len(indices) % self.page_size == 0
         ptr_list = []
-        key_list = []
         kv_buffer_data_ptr = self.kv_buffer.data_ptr()
         indices = indices.tolist()
-        for index in range(0, len(indices), self.page_size):
-            k_ptr = (
-                kv_buffer_data_ptr
-                + indices[index]
-                * self.layer_num
+        if self.layout == "layer_first":
+            for index in range(0, len(indices), self.page_size):
+                for layer_id in range(self.layer_num):
+                    k_ptr = (
+                        kv_buffer_data_ptr
+                        + indices[index]
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                        + layer_id
+                        * self.size
+                        * (self.kv_lora_rank + self.qk_rope_head_dim)
+                        * self.dtype.itemsize
+                    )
+                    ptr_list.append(k_ptr)
+            element_size = (
+                self.dtype.itemsize
+                * self.page_size
                 * (self.kv_lora_rank + self.qk_rope_head_dim)
+            )
+            element_size_list = [element_size] * len(ptr_list)
+        elif self.layout in ["page_first", "page_first_direct"]:
+            for index in range(0, len(indices), self.page_size):
+                k_ptr = (
+                    kv_buffer_data_ptr
+                    + indices[index]
+                    * self.layer_num
+                    * (self.kv_lora_rank + self.qk_rope_head_dim)
+                    * self.dtype.itemsize
+                )
+                ptr_list.append(k_ptr)
+            element_size = (
+                self.layer_num
                 * self.dtype.itemsize
+                * self.page_size
+                * (self.kv_lora_rank + self.qk_rope_head_dim)
             )
-            ptr_list.append(k_ptr)
-
-
-        element_size = (
-            self.layer_num
-            * self.dtype.itemsize
-            * self.page_size
-            * (self.kv_lora_rank + self.qk_rope_head_dim)
-        )
-        element_size_list = [element_size] * len(key_list)
-        return key_list, ptr_list, element_size_list
-
-    def get_buffer_with_hash(self, keys, indices=None):
-        assert self.layout == "page_first"
-        assert indices is None or (len(keys) == (len(indices) // self.page_size))
-
-        buf_list = []
-
-        if indices is not None:
-            for i in range(len(keys)):
-                index = indices[i * self.page_size]
-                buf_list.append(self.kv_buffer[index : index + self.page_size])
-
-        return keys, buf_list, 1
+            element_size_list = [element_size] * len(ptr_list)
+        else:
+            raise ValueError(f"Unsupported layout: {self.layout}")
+        return ptr_list, element_size_list