sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import math
|
|
18
18
|
import threading
|
19
19
|
import time
|
20
20
|
from queue import Empty, Full, PriorityQueue, Queue
|
21
|
-
from typing import TYPE_CHECKING, List, Optional
|
21
|
+
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
|
22
22
|
|
23
23
|
import torch
|
24
24
|
|
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
|
|
33
33
|
get_tensor_model_parallel_world_size,
|
34
34
|
)
|
35
35
|
from sglang.srt.layers.dp_attention import (
|
36
|
+
get_attention_dp_rank,
|
36
37
|
get_attention_tp_rank,
|
37
38
|
get_attention_tp_size,
|
38
39
|
is_dp_attention_enabled,
|
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
|
42
43
|
logger = logging.getLogger(__name__)
|
43
44
|
|
44
45
|
|
46
|
+
class LayerLoadingEvent:
|
47
|
+
def __init__(self, num_layers: int):
|
48
|
+
self._num_layers = num_layers
|
49
|
+
self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
|
50
|
+
self.start_event = torch.cuda.Event() # start event on controller stream
|
51
|
+
|
52
|
+
def complete(self, layer_index: int):
|
53
|
+
assert 0 <= layer_index < self._num_layers
|
54
|
+
self.load_events[layer_index].record()
|
55
|
+
|
56
|
+
def wait(self, layer_index: int):
|
57
|
+
torch.cuda.current_stream().wait_event(self.load_events[layer_index])
|
58
|
+
|
59
|
+
@property
|
60
|
+
def finish_event(self):
|
61
|
+
return self.load_events[-1]
|
62
|
+
|
63
|
+
|
45
64
|
class LayerDoneCounter:
|
46
|
-
def __init__(self, num_layers):
|
65
|
+
def __init__(self, num_layers: int):
|
47
66
|
self.num_layers = num_layers
|
48
67
|
# extra producer and consumer counters for overlap mode
|
49
68
|
self.num_counters = 3
|
50
|
-
self.
|
51
|
-
self.
|
52
|
-
self.
|
53
|
-
self.consumer_index = 0
|
54
|
-
|
55
|
-
def next_producer(self):
|
56
|
-
return (self.producer_index + 1) % self.num_counters
|
69
|
+
self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
|
70
|
+
self.producer_index = -1
|
71
|
+
self.consumer_index = -1
|
57
72
|
|
58
73
|
def update_producer(self):
|
59
|
-
self.producer_index = self.
|
74
|
+
self.producer_index = (self.producer_index + 1) % self.num_counters
|
75
|
+
assert self.events[
|
76
|
+
self.producer_index
|
77
|
+
].finish_event.query(), (
|
78
|
+
"Producer finish event should be ready before being reused."
|
79
|
+
)
|
60
80
|
return self.producer_index
|
61
81
|
|
62
|
-
def set_consumer(self, index):
|
82
|
+
def set_consumer(self, index: int):
|
63
83
|
self.consumer_index = index
|
64
84
|
|
65
|
-
def
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
def wait_until(self, threshold):
|
71
|
-
with self.conditions[self.consumer_index]:
|
72
|
-
while self.counters[self.consumer_index] <= threshold:
|
73
|
-
self.conditions[self.consumer_index].wait()
|
85
|
+
def wait_until(self, threshold: int):
|
86
|
+
if self.consumer_index < 0:
|
87
|
+
return
|
88
|
+
self.events[self.consumer_index].wait(threshold)
|
74
89
|
|
75
90
|
def reset(self):
|
76
|
-
|
77
|
-
|
91
|
+
self.producer_index = -1
|
92
|
+
self.consumer_index = -1
|
78
93
|
|
79
94
|
|
80
95
|
class CacheOperation:
|
@@ -98,36 +113,30 @@ class CacheOperation:
|
|
98
113
|
# default priority is the order of creation
|
99
114
|
self.priority = priority if priority is not None else self.id
|
100
115
|
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
device_indices=self.device_indices[i : i + chunk_size],
|
120
|
-
node_id=0,
|
121
|
-
)
|
122
|
-
)
|
123
|
-
# Inherit the node_ids on the final chunk
|
124
|
-
if split_ops:
|
125
|
-
split_ops[-1].node_ids = self.node_ids
|
116
|
+
@staticmethod
|
117
|
+
def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
|
118
|
+
assert len(ops) > 0
|
119
|
+
if len(ops) == 1:
|
120
|
+
return ops[0]
|
121
|
+
|
122
|
+
host_indices = torch.cat([op.host_indices for op in ops])
|
123
|
+
device_indices = torch.cat([op.device_indices for op in ops])
|
124
|
+
node_ids = []
|
125
|
+
priority = min(op.priority for op in ops)
|
126
|
+
for op in ops:
|
127
|
+
node_ids.extend(op.node_ids)
|
128
|
+
merged_op = CacheOperation(host_indices, device_indices, -1, priority)
|
129
|
+
merged_op.node_ids = node_ids
|
130
|
+
return merged_op
|
131
|
+
|
132
|
+
def __lt__(self, other: CacheOperation):
|
133
|
+
return self.priority < other.priority
|
126
134
|
|
127
|
-
return split_ops
|
128
135
|
|
129
|
-
|
130
|
-
|
136
|
+
class HiCacheAck(NamedTuple):
|
137
|
+
start_event: torch.cuda.Event
|
138
|
+
finish_event: torch.cuda.Event
|
139
|
+
node_ids: List[int]
|
131
140
|
|
132
141
|
|
133
142
|
class TransferBuffer:
|
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
|
|
206
215
|
):
|
207
216
|
self.request_id = request_id
|
208
217
|
|
209
|
-
self._done_flag = False
|
210
218
|
self._lock = threading.Lock()
|
211
|
-
|
219
|
+
self._terminated_flag = False
|
212
220
|
self.start_time = time.monotonic()
|
213
221
|
|
214
222
|
super().__init__(host_indices, token_ids, last_hash)
|
215
223
|
|
216
224
|
def increment(self, num_tokens: int):
|
217
225
|
with self._lock:
|
218
|
-
if self.
|
226
|
+
if self._terminated_flag:
|
219
227
|
return False
|
220
228
|
self.completed_tokens += num_tokens
|
221
229
|
return True
|
222
230
|
|
223
|
-
def
|
231
|
+
def mark_terminate(self):
|
224
232
|
with self._lock:
|
225
|
-
self.
|
233
|
+
self._terminated_flag = True
|
226
234
|
|
227
|
-
def
|
228
|
-
return self.
|
235
|
+
def is_terminated(self) -> bool:
|
236
|
+
return self._terminated_flag
|
229
237
|
|
230
238
|
|
231
239
|
class HiCacheController:
|
@@ -236,13 +244,13 @@ class HiCacheController:
|
|
236
244
|
mem_pool_host: HostKVCache,
|
237
245
|
page_size: int,
|
238
246
|
tp_group: torch.distributed.ProcessGroup,
|
239
|
-
load_cache_event: threading.Event
|
247
|
+
load_cache_event: threading.Event,
|
240
248
|
write_policy: str = "write_through_selective",
|
241
249
|
io_backend: str = "",
|
242
250
|
storage_backend: Optional[str] = None,
|
243
251
|
prefetch_threshold: int = 256,
|
244
252
|
model_name: Optional[str] = None,
|
245
|
-
storage_backend_extra_config: Optional[
|
253
|
+
storage_backend_extra_config: Optional[dict] = None,
|
246
254
|
):
|
247
255
|
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
248
256
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
@@ -267,43 +275,17 @@ class HiCacheController:
|
|
267
275
|
and self.storage_config.tp_rank != 0
|
268
276
|
)
|
269
277
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
self.storage_backend = HiCacheFile(self.storage_config)
|
274
|
-
elif storage_backend == "nixl":
|
275
|
-
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
278
|
+
# Use storage backend factory for dynamic backend creation
|
279
|
+
from sglang.srt.mem_cache.storage import StorageBackendFactory
|
276
280
|
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
MooncakeStore,
|
281
|
-
)
|
282
|
-
|
283
|
-
self.storage_backend = MooncakeStore(self.storage_config)
|
284
|
-
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
|
285
|
-
assert self.mem_pool_host.layout == "page_first"
|
286
|
-
elif storage_backend == "hf3fs":
|
287
|
-
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
288
|
-
HiCacheHF3FS,
|
281
|
+
try:
|
282
|
+
self.storage_backend = StorageBackendFactory.create_backend(
|
283
|
+
storage_backend, self.storage_config, self.mem_pool_host
|
289
284
|
)
|
285
|
+
except ValueError as e:
|
286
|
+
raise ValueError(f"Failed to create storage backend: {e}") from e
|
290
287
|
|
291
|
-
|
292
|
-
bytes_per_page = (
|
293
|
-
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
294
|
-
)
|
295
|
-
elif self.mem_pool_host.layout == "layer_first":
|
296
|
-
bytes_per_page = (
|
297
|
-
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
298
|
-
)
|
299
|
-
dtype = mem_pool_host.dtype
|
300
|
-
self.storage_backend = HiCacheHF3FS.from_env_config(
|
301
|
-
bytes_per_page, dtype, self.storage_config
|
302
|
-
)
|
303
|
-
else:
|
304
|
-
raise NotImplementedError(
|
305
|
-
f"Unsupported storage backend: {storage_backend}"
|
306
|
-
)
|
288
|
+
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
|
307
289
|
|
308
290
|
self.enable_storage = True
|
309
291
|
# todo: threshold policy for prefetching
|
@@ -324,8 +306,17 @@ class HiCacheController:
|
|
324
306
|
group_ranks, backend="gloo"
|
325
307
|
)
|
326
308
|
|
327
|
-
|
328
|
-
|
309
|
+
# Select the get and set functions
|
310
|
+
self.page_get_func = self._generic_page_get
|
311
|
+
self.page_set_func = self._generic_page_set
|
312
|
+
|
313
|
+
if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
|
314
|
+
self.page_get_func = self._page_get_zero_copy
|
315
|
+
self.page_set_func = self._page_set_zero_copy
|
316
|
+
|
317
|
+
self.device = self.mem_pool_device.device
|
318
|
+
self.layer_num = self.mem_pool_device.layer_num
|
319
|
+
self.layer_done_counter = LayerDoneCounter(self.layer_num)
|
329
320
|
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
|
330
321
|
|
331
322
|
if write_policy not in [
|
@@ -335,11 +326,11 @@ class HiCacheController:
|
|
335
326
|
]:
|
336
327
|
raise ValueError(f"Invalid write policy: {write_policy}")
|
337
328
|
|
338
|
-
self.write_queue = PriorityQueue()
|
339
|
-
self.load_queue =
|
340
|
-
|
341
|
-
self.
|
342
|
-
self.
|
329
|
+
# self.write_queue = PriorityQueue[CacheOperation]()
|
330
|
+
self.load_queue: List[CacheOperation] = []
|
331
|
+
self.write_queue: List[CacheOperation] = []
|
332
|
+
self.ack_load_queue: List[HiCacheAck] = []
|
333
|
+
self.ack_write_queue: List[HiCacheAck] = []
|
343
334
|
|
344
335
|
self.stop_event = threading.Event()
|
345
336
|
self.write_buffer = TransferBuffer(self.stop_event)
|
@@ -350,16 +341,6 @@ class HiCacheController:
|
|
350
341
|
self.write_stream = torch.cuda.Stream()
|
351
342
|
self.load_stream = torch.cuda.Stream()
|
352
343
|
|
353
|
-
self.write_thread = threading.Thread(
|
354
|
-
target=self.write_thread_func_direct, daemon=True
|
355
|
-
)
|
356
|
-
self.load_thread = threading.Thread(
|
357
|
-
target=self.load_thread_func_layer_by_layer, daemon=True
|
358
|
-
)
|
359
|
-
|
360
|
-
self.write_thread.start()
|
361
|
-
self.load_thread.start()
|
362
|
-
|
363
344
|
if self.enable_storage:
|
364
345
|
self.prefetch_thread = threading.Thread(
|
365
346
|
target=self.prefetch_thread_func, daemon=True
|
@@ -380,48 +361,39 @@ class HiCacheController:
|
|
380
361
|
def _generate_storage_config(
|
381
362
|
self,
|
382
363
|
model_name: Optional[str] = None,
|
383
|
-
storage_backend_extra_config: Optional[
|
364
|
+
storage_backend_extra_config: Optional[dict] = None,
|
384
365
|
):
|
385
366
|
|
386
367
|
if is_dp_attention_enabled():
|
387
368
|
self.tp_rank = get_attention_tp_rank()
|
388
369
|
self.tp_size = get_attention_tp_size()
|
370
|
+
self.dp_rank = get_attention_dp_rank()
|
389
371
|
else:
|
390
372
|
self.tp_rank = get_tensor_model_parallel_rank()
|
391
373
|
self.tp_size = get_tensor_model_parallel_world_size()
|
374
|
+
self.dp_rank = 0
|
392
375
|
|
393
376
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
394
377
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
395
378
|
|
396
|
-
# Parse extra config JSON if provided
|
397
|
-
extra_config = None
|
398
|
-
if storage_backend_extra_config:
|
399
|
-
try:
|
400
|
-
import json
|
401
|
-
|
402
|
-
extra_config = json.loads(storage_backend_extra_config)
|
403
|
-
except Exception as e:
|
404
|
-
logger.error(f"Invalid backend extra config JSON: {e}")
|
405
|
-
|
406
379
|
return HiCacheStorageConfig(
|
407
380
|
tp_rank=self.tp_rank,
|
408
381
|
tp_size=self.tp_size,
|
409
382
|
is_mla_model=is_mla_backend,
|
383
|
+
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
410
384
|
model_name=model_name,
|
411
|
-
extra_config=
|
385
|
+
extra_config=storage_backend_extra_config,
|
412
386
|
)
|
413
387
|
|
414
388
|
def reset(self):
|
415
389
|
self.stop_event.set()
|
416
|
-
self.write_thread.join()
|
417
|
-
self.load_thread.join()
|
418
390
|
|
419
|
-
self.write_queue.
|
420
|
-
self.load_queue.
|
391
|
+
self.write_queue.clear()
|
392
|
+
self.load_queue.clear()
|
421
393
|
self.write_buffer.clear()
|
422
394
|
self.load_buffer.clear()
|
423
|
-
self.ack_write_queue.
|
424
|
-
self.ack_load_queue.
|
395
|
+
self.ack_write_queue.clear()
|
396
|
+
self.ack_load_queue.clear()
|
425
397
|
if self.enable_storage:
|
426
398
|
self.prefetch_thread.join()
|
427
399
|
self.backup_thread.join()
|
@@ -430,15 +402,7 @@ class HiCacheController:
|
|
430
402
|
self.prefetch_revoke_queue.queue.clear()
|
431
403
|
self.ack_backup_queue.queue.clear()
|
432
404
|
|
433
|
-
self.write_thread = threading.Thread(
|
434
|
-
target=self.write_thread_func_direct, daemon=True
|
435
|
-
)
|
436
|
-
self.load_thread = threading.Thread(
|
437
|
-
target=self.load_thread_func_layer_by_layer, daemon=True
|
438
|
-
)
|
439
405
|
self.stop_event.clear()
|
440
|
-
self.write_thread.start()
|
441
|
-
self.load_thread.start()
|
442
406
|
|
443
407
|
if self.enable_storage:
|
444
408
|
self.prefetch_thread = threading.Thread(
|
@@ -454,7 +418,7 @@ class HiCacheController:
|
|
454
418
|
self,
|
455
419
|
device_indices: torch.Tensor,
|
456
420
|
priority: Optional[int] = None,
|
457
|
-
node_id: int =
|
421
|
+
node_id: int = -1,
|
458
422
|
) -> Optional[torch.Tensor]:
|
459
423
|
"""
|
460
424
|
Back up KV caches from device memory to host memory.
|
@@ -462,18 +426,45 @@ class HiCacheController:
|
|
462
426
|
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
463
427
|
if host_indices is None:
|
464
428
|
return None
|
465
|
-
self.
|
466
|
-
torch.cuda.current_stream().synchronize()
|
467
|
-
self.write_queue.put(
|
429
|
+
self.write_queue.append(
|
468
430
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
469
431
|
)
|
432
|
+
self.start_writing()
|
470
433
|
return host_indices
|
471
434
|
|
435
|
+
def start_writing(self) -> None:
|
436
|
+
if len(self.write_queue) == 0:
|
437
|
+
return
|
438
|
+
|
439
|
+
op = CacheOperation.merge_ops(self.write_queue)
|
440
|
+
host_indices, device_indices = self.move_indices(op)
|
441
|
+
self.write_queue.clear()
|
442
|
+
|
443
|
+
start_event = torch.cuda.Event()
|
444
|
+
finish_event = torch.cuda.Event()
|
445
|
+
|
446
|
+
start_event.record()
|
447
|
+
with torch.cuda.stream(self.write_stream):
|
448
|
+
start_event.wait(self.write_stream)
|
449
|
+
self.mem_pool_host.backup_from_device_all_layer(
|
450
|
+
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
451
|
+
)
|
452
|
+
finish_event.record()
|
453
|
+
# NOTE: We must save the host indices and device indices here,
|
454
|
+
# this is because we need to guarantee that these tensors are
|
455
|
+
# still alive when the write stream is executing.
|
456
|
+
if host_indices.is_cuda:
|
457
|
+
host_indices.record_stream(self.write_stream)
|
458
|
+
if device_indices.is_cuda:
|
459
|
+
device_indices.record_stream(self.write_stream)
|
460
|
+
|
461
|
+
self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
|
462
|
+
|
472
463
|
def load(
|
473
464
|
self,
|
474
465
|
host_indices: torch.Tensor,
|
475
466
|
priority: Optional[int] = None,
|
476
|
-
node_id: int =
|
467
|
+
node_id: int = -1,
|
477
468
|
) -> Optional[torch.Tensor]:
|
478
469
|
"""
|
479
470
|
Load KV caches from host memory to device memory.
|
@@ -481,77 +472,42 @@ class HiCacheController:
|
|
481
472
|
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
482
473
|
if device_indices is None:
|
483
474
|
return None
|
484
|
-
self.
|
485
|
-
# to ensure the device indices are ready before accessed by another CUDA stream
|
486
|
-
torch.cuda.current_stream().synchronize()
|
487
|
-
self.load_queue.put(
|
475
|
+
self.load_queue.append(
|
488
476
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
489
477
|
)
|
490
478
|
return device_indices
|
491
479
|
|
492
|
-
def move_indices(self,
|
480
|
+
def move_indices(self, op: CacheOperation):
|
481
|
+
host_indices, device_indices = op.host_indices, op.device_indices
|
493
482
|
# move indices to GPU if using kernels, to host if using direct indexing
|
494
483
|
if self.io_backend == "kernel":
|
495
|
-
|
484
|
+
if not host_indices.is_cuda:
|
485
|
+
host_indices = host_indices.to(self.device, non_blocking=True)
|
486
|
+
return host_indices, device_indices
|
496
487
|
elif self.io_backend == "direct":
|
497
|
-
|
498
|
-
|
499
|
-
|
488
|
+
if self.mem_pool_host.layout == "layer_first":
|
489
|
+
device_indices = device_indices.cpu()
|
490
|
+
host_indices, idx = host_indices.sort()
|
491
|
+
return host_indices, device_indices.index_select(0, idx)
|
492
|
+
elif self.mem_pool_host.layout == "page_first_direct":
|
493
|
+
return host_indices, device_indices.cpu()
|
500
494
|
else:
|
501
495
|
raise ValueError(f"Unsupported io backend")
|
502
496
|
|
503
|
-
def
|
504
|
-
|
505
|
-
|
506
|
-
"""
|
507
|
-
torch.cuda.set_stream(self.write_stream)
|
508
|
-
while not self.stop_event.is_set():
|
509
|
-
try:
|
510
|
-
operation = self.write_queue.get(block=True, timeout=1)
|
511
|
-
host_indices, device_indices = self.move_indices(
|
512
|
-
operation.host_indices, operation.device_indices
|
513
|
-
)
|
514
|
-
self.mem_pool_host.backup_from_device_all_layer(
|
515
|
-
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
516
|
-
)
|
517
|
-
self.write_stream.synchronize()
|
518
|
-
self.mem_pool_host.complete_io(operation.host_indices)
|
519
|
-
for node_id in operation.node_ids:
|
520
|
-
if node_id != 0:
|
521
|
-
self.ack_write_queue.put(node_id)
|
522
|
-
except Empty:
|
523
|
-
continue
|
524
|
-
except Exception as e:
|
525
|
-
logger.error(e)
|
497
|
+
def start_loading(self) -> int:
|
498
|
+
if len(self.load_queue) == 0:
|
499
|
+
return -1
|
526
500
|
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
self.load_cache_event.wait(timeout=1)
|
534
|
-
if not self.load_cache_event.is_set():
|
535
|
-
continue
|
536
|
-
self.load_cache_event.clear()
|
537
|
-
self.layer_done_counter.update_producer()
|
538
|
-
|
539
|
-
batch_operation = None
|
540
|
-
while self.load_queue.qsize() > 0:
|
541
|
-
op = self.load_queue.get(block=True)
|
542
|
-
if batch_operation is None:
|
543
|
-
batch_operation = op
|
544
|
-
else:
|
545
|
-
batch_operation.merge(op)
|
546
|
-
if batch_operation is None:
|
547
|
-
continue
|
501
|
+
producer_id = self.layer_done_counter.update_producer()
|
502
|
+
op = CacheOperation.merge_ops(self.load_queue)
|
503
|
+
host_indices, device_indices = self.move_indices(op)
|
504
|
+
self.load_queue.clear()
|
505
|
+
producer_event = self.layer_done_counter.events[producer_id]
|
506
|
+
producer_event.start_event.record()
|
548
507
|
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
batch_operation.host_indices, batch_operation.device_indices
|
553
|
-
)
|
554
|
-
for i in range(self.mem_pool_host.layer_num):
|
508
|
+
with torch.cuda.stream(self.load_stream):
|
509
|
+
producer_event.start_event.wait(self.load_stream)
|
510
|
+
for i in range(self.layer_num):
|
555
511
|
self.mem_pool_host.load_to_device_per_layer(
|
556
512
|
self.mem_pool_device,
|
557
513
|
host_indices,
|
@@ -559,37 +515,34 @@ class HiCacheController:
|
|
559
515
|
i,
|
560
516
|
self.io_backend,
|
561
517
|
)
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
self
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
576
|
-
return len(device_indices)
|
577
|
-
else:
|
578
|
-
raise ValueError(
|
579
|
-
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
518
|
+
producer_event.complete(i)
|
519
|
+
# NOTE: We must save the host indices and device indices here,
|
520
|
+
# this is because we need to guarantee that these tensors are
|
521
|
+
# still alive when the load stream is executing.
|
522
|
+
if host_indices.is_cuda:
|
523
|
+
host_indices.record_stream(self.load_stream)
|
524
|
+
if device_indices.is_cuda:
|
525
|
+
device_indices.record_stream(self.load_stream)
|
526
|
+
|
527
|
+
self.ack_load_queue.append(
|
528
|
+
HiCacheAck(
|
529
|
+
start_event=producer_event.start_event,
|
530
|
+
finish_event=producer_event.finish_event,
|
531
|
+
node_ids=op.node_ids,
|
580
532
|
)
|
533
|
+
)
|
534
|
+
return producer_id
|
535
|
+
|
536
|
+
def evict_device(self, device_indices: torch.Tensor) -> int:
|
537
|
+
self.mem_pool_device_allocator.free(device_indices)
|
538
|
+
return len(device_indices)
|
581
539
|
|
582
540
|
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
583
541
|
if not backup_only:
|
584
542
|
raise ValueError("Other eviction policies are not supported yet.")
|
585
543
|
|
586
|
-
|
587
|
-
|
588
|
-
return len(host_indices)
|
589
|
-
else:
|
590
|
-
raise ValueError(
|
591
|
-
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
592
|
-
)
|
544
|
+
self.mem_pool_host.free(host_indices)
|
545
|
+
return len(host_indices)
|
593
546
|
|
594
547
|
def prefetch(
|
595
548
|
self,
|
@@ -608,48 +561,33 @@ class HiCacheController:
|
|
608
561
|
return operation
|
609
562
|
|
610
563
|
def terminate_prefetch(self, operation):
|
611
|
-
operation.
|
564
|
+
operation.mark_terminate()
|
612
565
|
return operation.completed_tokens, operation.hash_value
|
613
566
|
|
614
567
|
def append_host_mem_release(self, host_indices: torch.Tensor):
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
621
|
-
hash_values, host_indices
|
622
|
-
)
|
623
|
-
page_data = self.storage_backend.batch_get(hashes, dsts)
|
624
|
-
if page_data:
|
625
|
-
operation.increment(self.page_size * len(hashes))
|
626
|
-
else:
|
627
|
-
logger.warning(
|
628
|
-
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
629
|
-
)
|
568
|
+
if host_indices.numel() == 0:
|
569
|
+
return
|
570
|
+
pages = host_indices.split(self.mem_pool_host.page_size)
|
571
|
+
for page in pages:
|
572
|
+
self.host_mem_release_queue.put(page)
|
630
573
|
|
631
|
-
def
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
|
638
|
-
|
639
|
-
|
640
|
-
|
641
|
-
)
|
642
|
-
if get_result != len(hash_values):
|
643
|
-
logger.warning(
|
644
|
-
f"Prefetch operation {operation.request_id} failed or partially failed."
|
645
|
-
)
|
646
|
-
if get_result != 0:
|
647
|
-
operation.increment(get_result * self.page_size)
|
574
|
+
def _page_get_zero_copy(self, operation, hash_values, host_indices):
|
575
|
+
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
|
576
|
+
inc = 0
|
577
|
+
for i in range(len(hash_values)):
|
578
|
+
if not results[i]:
|
579
|
+
logger.warning(
|
580
|
+
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
581
|
+
)
|
582
|
+
break
|
583
|
+
inc += self.page_size
|
584
|
+
operation.increment(inc)
|
648
585
|
|
586
|
+
# todo: deprecate
|
649
587
|
def _generic_page_get(self, operation, hash_values, host_indices):
|
650
|
-
dummy_page_dst = [
|
651
|
-
hash_values
|
652
|
-
|
588
|
+
dummy_page_dst = [
|
589
|
+
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
590
|
+
]
|
653
591
|
page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
|
654
592
|
if page_data is None:
|
655
593
|
return
|
@@ -659,26 +597,16 @@ class HiCacheController:
|
|
659
597
|
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
660
598
|
)
|
661
599
|
break
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
600
|
+
# Must set the data before increasing the completed tokens.
|
601
|
+
# Otherwise this page may be read before being set.
|
602
|
+
self.mem_pool_host.set_from_flat_data_page(
|
603
|
+
host_indices[i * self.page_size],
|
604
|
+
page_data[i],
|
605
|
+
)
|
606
|
+
if not operation.increment(self.page_size):
|
607
|
+
break # Operation terminated by controller
|
669
608
|
|
670
609
|
def _page_transfer(self, operation):
|
671
|
-
# Select the get function and batch size
|
672
|
-
if self.storage_backend_type == "mooncake":
|
673
|
-
get_func = self._mooncake_page_get
|
674
|
-
elif (
|
675
|
-
self.storage_backend_type == "hf3fs"
|
676
|
-
and self.mem_pool_host.layout == "page_first"
|
677
|
-
):
|
678
|
-
get_func = self._3fs_zero_copy_page_get
|
679
|
-
else:
|
680
|
-
get_func = self._generic_page_get
|
681
|
-
|
682
610
|
# Transfer batch by batch
|
683
611
|
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
684
612
|
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
@@ -687,12 +615,13 @@ class HiCacheController:
|
|
687
615
|
]
|
688
616
|
prev_completed_tokens = operation.completed_tokens
|
689
617
|
# Get one batch token, and update the completed_tokens if succeed
|
690
|
-
|
618
|
+
self.page_get_func(operation, batch_hashes, batch_host_indices)
|
691
619
|
# Check termination
|
692
620
|
if (
|
693
621
|
operation.completed_tokens
|
694
622
|
!= prev_completed_tokens + len(batch_hashes) * self.page_size
|
695
623
|
):
|
624
|
+
operation.mark_terminate()
|
696
625
|
break # Some operations fail or operation terminated by controller
|
697
626
|
# release pre-allocated memory
|
698
627
|
self.append_host_mem_release(
|
@@ -813,47 +742,19 @@ class HiCacheController:
|
|
813
742
|
self.backup_queue.put(operation)
|
814
743
|
return operation.id
|
815
744
|
|
816
|
-
#
|
745
|
+
# todo: deprecate
|
817
746
|
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
818
747
|
data = [
|
819
|
-
self.mem_pool_host.
|
748
|
+
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
820
749
|
for i in range(len(hash_values))
|
821
750
|
]
|
822
751
|
return self.storage_backend.batch_set(hash_values, data)
|
823
752
|
|
824
|
-
|
825
|
-
|
826
|
-
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
827
|
-
hash_values,
|
828
|
-
host_indices,
|
829
|
-
self.storage_config.tp_rank,
|
830
|
-
)
|
831
|
-
success = self.storage_backend.batch_set(
|
832
|
-
key_strs,
|
833
|
-
target_location=buffer_ptrs,
|
834
|
-
target_sizes=buffer_sizes,
|
835
|
-
)
|
836
|
-
return success
|
837
|
-
|
838
|
-
# zero copy
|
839
|
-
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
840
|
-
hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
|
841
|
-
hash_values, host_indices
|
842
|
-
)
|
843
|
-
return self.storage_backend.batch_set(hashes, dsts)
|
753
|
+
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
754
|
+
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
|
844
755
|
|
845
756
|
# Backup batch by batch
|
846
757
|
def _page_backup(self, operation):
|
847
|
-
# Select the set function and batch size
|
848
|
-
if self.storage_backend_type == "mooncake":
|
849
|
-
backup_set_func = self._mooncake_page_set
|
850
|
-
elif (
|
851
|
-
self.storage_backend_type == "hf3fs"
|
852
|
-
and self.mem_pool_host.layout == "page_first"
|
853
|
-
):
|
854
|
-
backup_set_func = self._3fs_zero_copy_page_set
|
855
|
-
else:
|
856
|
-
backup_set_func = self._generic_page_set
|
857
758
|
# Backup batch by batch
|
858
759
|
for i in range(0, len(operation.hash_value), self.storage_batch_size):
|
859
760
|
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
|
@@ -862,7 +763,7 @@ class HiCacheController:
|
|
862
763
|
]
|
863
764
|
# Set one batch token, and record if success.
|
864
765
|
# todo: allow partial success
|
865
|
-
success =
|
766
|
+
success = self.page_set_func(batch_hashes, batch_host_indices)
|
866
767
|
if not success:
|
867
768
|
logger.warning(
|
868
769
|
f"Write page to storage: {len(batch_hashes)} pages failed."
|
@@ -882,7 +783,7 @@ class HiCacheController:
|
|
882
783
|
|
883
784
|
if not self.backup_skip:
|
884
785
|
self._page_backup(operation)
|
885
|
-
self.ack_backup_queue.put(operation
|
786
|
+
self.ack_backup_queue.put(operation)
|
886
787
|
|
887
788
|
except Empty:
|
888
789
|
continue
|