sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,7 @@ import math
|
|
18
18
|
import threading
|
19
19
|
import time
|
20
20
|
from queue import Empty, Full, PriorityQueue, Queue
|
21
|
-
from typing import TYPE_CHECKING, List, Optional
|
21
|
+
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
|
22
22
|
|
23
23
|
import torch
|
24
24
|
|
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
|
|
33
33
|
get_tensor_model_parallel_world_size,
|
34
34
|
)
|
35
35
|
from sglang.srt.layers.dp_attention import (
|
36
|
+
get_attention_dp_rank,
|
36
37
|
get_attention_tp_rank,
|
37
38
|
get_attention_tp_size,
|
38
39
|
is_dp_attention_enabled,
|
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
|
|
42
43
|
logger = logging.getLogger(__name__)
|
43
44
|
|
44
45
|
|
46
|
+
class LayerLoadingEvent:
|
47
|
+
def __init__(self, num_layers: int):
|
48
|
+
self._num_layers = num_layers
|
49
|
+
self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
|
50
|
+
self.start_event = torch.cuda.Event() # start event on controller stream
|
51
|
+
|
52
|
+
def complete(self, layer_index: int):
|
53
|
+
assert 0 <= layer_index < self._num_layers
|
54
|
+
self.load_events[layer_index].record()
|
55
|
+
|
56
|
+
def wait(self, layer_index: int):
|
57
|
+
torch.cuda.current_stream().wait_event(self.load_events[layer_index])
|
58
|
+
|
59
|
+
@property
|
60
|
+
def finish_event(self):
|
61
|
+
return self.load_events[-1]
|
62
|
+
|
63
|
+
|
45
64
|
class LayerDoneCounter:
|
46
|
-
def __init__(self, num_layers):
|
65
|
+
def __init__(self, num_layers: int):
|
47
66
|
self.num_layers = num_layers
|
48
67
|
# extra producer and consumer counters for overlap mode
|
49
68
|
self.num_counters = 3
|
50
|
-
self.
|
51
|
-
self.
|
52
|
-
self.
|
53
|
-
self.consumer_index = 0
|
54
|
-
|
55
|
-
def next_producer(self):
|
56
|
-
return (self.producer_index + 1) % self.num_counters
|
69
|
+
self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
|
70
|
+
self.producer_index = -1
|
71
|
+
self.consumer_index = -1
|
57
72
|
|
58
73
|
def update_producer(self):
|
59
|
-
self.producer_index = self.
|
74
|
+
self.producer_index = (self.producer_index + 1) % self.num_counters
|
75
|
+
assert self.events[
|
76
|
+
self.producer_index
|
77
|
+
].finish_event.query(), (
|
78
|
+
"Producer finish event should be ready before being reused."
|
79
|
+
)
|
60
80
|
return self.producer_index
|
61
81
|
|
62
|
-
def set_consumer(self, index):
|
82
|
+
def set_consumer(self, index: int):
|
63
83
|
self.consumer_index = index
|
64
84
|
|
65
|
-
def
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
def wait_until(self, threshold):
|
71
|
-
with self.conditions[self.consumer_index]:
|
72
|
-
while self.counters[self.consumer_index] <= threshold:
|
73
|
-
self.conditions[self.consumer_index].wait()
|
85
|
+
def wait_until(self, threshold: int):
|
86
|
+
if self.consumer_index < 0:
|
87
|
+
return
|
88
|
+
self.events[self.consumer_index].wait(threshold)
|
74
89
|
|
75
90
|
def reset(self):
|
76
|
-
|
77
|
-
|
91
|
+
self.producer_index = -1
|
92
|
+
self.consumer_index = -1
|
78
93
|
|
79
94
|
|
80
95
|
class CacheOperation:
|
@@ -98,36 +113,30 @@ class CacheOperation:
|
|
98
113
|
# default priority is the order of creation
|
99
114
|
self.priority = priority if priority is not None else self.id
|
100
115
|
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
device_indices=self.device_indices[i : i + chunk_size],
|
120
|
-
node_id=0,
|
121
|
-
)
|
122
|
-
)
|
123
|
-
# Inherit the node_ids on the final chunk
|
124
|
-
if split_ops:
|
125
|
-
split_ops[-1].node_ids = self.node_ids
|
116
|
+
@staticmethod
|
117
|
+
def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
|
118
|
+
assert len(ops) > 0
|
119
|
+
if len(ops) == 1:
|
120
|
+
return ops[0]
|
121
|
+
|
122
|
+
host_indices = torch.cat([op.host_indices for op in ops])
|
123
|
+
device_indices = torch.cat([op.device_indices for op in ops])
|
124
|
+
node_ids = []
|
125
|
+
priority = min(op.priority for op in ops)
|
126
|
+
for op in ops:
|
127
|
+
node_ids.extend(op.node_ids)
|
128
|
+
merged_op = CacheOperation(host_indices, device_indices, -1, priority)
|
129
|
+
merged_op.node_ids = node_ids
|
130
|
+
return merged_op
|
131
|
+
|
132
|
+
def __lt__(self, other: CacheOperation):
|
133
|
+
return self.priority < other.priority
|
126
134
|
|
127
|
-
return split_ops
|
128
135
|
|
129
|
-
|
130
|
-
|
136
|
+
class HiCacheAck(NamedTuple):
|
137
|
+
start_event: torch.cuda.Event
|
138
|
+
finish_event: torch.cuda.Event
|
139
|
+
node_ids: List[int]
|
131
140
|
|
132
141
|
|
133
142
|
class TransferBuffer:
|
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
|
|
206
215
|
):
|
207
216
|
self.request_id = request_id
|
208
217
|
|
209
|
-
self._done_flag = False
|
210
218
|
self._lock = threading.Lock()
|
211
|
-
|
219
|
+
self._terminated_flag = False
|
212
220
|
self.start_time = time.monotonic()
|
213
221
|
|
214
222
|
super().__init__(host_indices, token_ids, last_hash)
|
215
223
|
|
216
224
|
def increment(self, num_tokens: int):
|
217
225
|
with self._lock:
|
218
|
-
if self.
|
226
|
+
if self._terminated_flag:
|
219
227
|
return False
|
220
228
|
self.completed_tokens += num_tokens
|
221
229
|
return True
|
222
230
|
|
223
|
-
def
|
231
|
+
def mark_terminate(self):
|
224
232
|
with self._lock:
|
225
|
-
self.
|
233
|
+
self._terminated_flag = True
|
226
234
|
|
227
|
-
def
|
228
|
-
return self.
|
235
|
+
def is_terminated(self) -> bool:
|
236
|
+
return self._terminated_flag
|
229
237
|
|
230
238
|
|
231
239
|
class HiCacheController:
|
@@ -236,13 +244,13 @@ class HiCacheController:
|
|
236
244
|
mem_pool_host: HostKVCache,
|
237
245
|
page_size: int,
|
238
246
|
tp_group: torch.distributed.ProcessGroup,
|
239
|
-
load_cache_event: threading.Event
|
247
|
+
load_cache_event: threading.Event,
|
240
248
|
write_policy: str = "write_through_selective",
|
241
249
|
io_backend: str = "",
|
242
250
|
storage_backend: Optional[str] = None,
|
243
251
|
prefetch_threshold: int = 256,
|
244
252
|
model_name: Optional[str] = None,
|
245
|
-
storage_backend_extra_config: Optional[
|
253
|
+
storage_backend_extra_config: Optional[dict] = None,
|
246
254
|
):
|
247
255
|
self.mem_pool_device_allocator = token_to_kv_pool_allocator
|
248
256
|
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
|
@@ -267,43 +275,17 @@ class HiCacheController:
|
|
267
275
|
and self.storage_config.tp_rank != 0
|
268
276
|
)
|
269
277
|
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
self.storage_backend = HiCacheFile(self.storage_config)
|
274
|
-
elif storage_backend == "nixl":
|
275
|
-
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
|
276
|
-
|
277
|
-
self.storage_backend = HiCacheNixl()
|
278
|
-
elif storage_backend == "mooncake":
|
279
|
-
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
|
280
|
-
MooncakeStore,
|
281
|
-
)
|
278
|
+
# Use storage backend factory for dynamic backend creation
|
279
|
+
from sglang.srt.mem_cache.storage import StorageBackendFactory
|
282
280
|
|
283
|
-
|
284
|
-
self.storage_backend.
|
285
|
-
|
286
|
-
elif storage_backend == "hf3fs":
|
287
|
-
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
|
288
|
-
HiCacheHF3FS,
|
281
|
+
try:
|
282
|
+
self.storage_backend = StorageBackendFactory.create_backend(
|
283
|
+
storage_backend, self.storage_config, self.mem_pool_host
|
289
284
|
)
|
285
|
+
except ValueError as e:
|
286
|
+
raise ValueError(f"Failed to create storage backend: {e}") from e
|
290
287
|
|
291
|
-
|
292
|
-
bytes_per_page = (
|
293
|
-
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
|
294
|
-
)
|
295
|
-
elif self.mem_pool_host.layout == "layer_first":
|
296
|
-
bytes_per_page = (
|
297
|
-
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
|
298
|
-
)
|
299
|
-
dtype = mem_pool_host.dtype
|
300
|
-
self.storage_backend = HiCacheHF3FS.from_env_config(
|
301
|
-
bytes_per_page, dtype, self.storage_config
|
302
|
-
)
|
303
|
-
else:
|
304
|
-
raise NotImplementedError(
|
305
|
-
f"Unsupported storage backend: {storage_backend}"
|
306
|
-
)
|
288
|
+
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
|
307
289
|
|
308
290
|
self.enable_storage = True
|
309
291
|
# todo: threshold policy for prefetching
|
@@ -327,21 +309,14 @@ class HiCacheController:
|
|
327
309
|
# Select the get and set functions
|
328
310
|
self.page_get_func = self._generic_page_get
|
329
311
|
self.page_set_func = self._generic_page_set
|
330
|
-
|
331
|
-
self.
|
332
|
-
self.
|
333
|
-
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
elif self.is_3fs_zerocopy:
|
339
|
-
self.page_get_func = self._3fs_zero_copy_page_get
|
340
|
-
self.page_set_func = self._3fs_zero_copy_page_set
|
341
|
-
self.batch_exists_func = self._3fs_zero_copy_batch_exists
|
342
|
-
|
343
|
-
self.load_cache_event = load_cache_event
|
344
|
-
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
|
312
|
+
|
313
|
+
if self.storage_backend_type in ["hf3fs", "mooncake", "eic"]:
|
314
|
+
self.page_get_func = self._page_get_zero_copy
|
315
|
+
self.page_set_func = self._page_set_zero_copy
|
316
|
+
|
317
|
+
self.device = self.mem_pool_device.device
|
318
|
+
self.layer_num = self.mem_pool_device.layer_num
|
319
|
+
self.layer_done_counter = LayerDoneCounter(self.layer_num)
|
345
320
|
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
|
346
321
|
|
347
322
|
if write_policy not in [
|
@@ -351,11 +326,11 @@ class HiCacheController:
|
|
351
326
|
]:
|
352
327
|
raise ValueError(f"Invalid write policy: {write_policy}")
|
353
328
|
|
354
|
-
self.write_queue = PriorityQueue()
|
355
|
-
self.load_queue =
|
356
|
-
|
357
|
-
self.
|
358
|
-
self.
|
329
|
+
# self.write_queue = PriorityQueue[CacheOperation]()
|
330
|
+
self.load_queue: List[CacheOperation] = []
|
331
|
+
self.write_queue: List[CacheOperation] = []
|
332
|
+
self.ack_load_queue: List[HiCacheAck] = []
|
333
|
+
self.ack_write_queue: List[HiCacheAck] = []
|
359
334
|
|
360
335
|
self.stop_event = threading.Event()
|
361
336
|
self.write_buffer = TransferBuffer(self.stop_event)
|
@@ -366,16 +341,6 @@ class HiCacheController:
|
|
366
341
|
self.write_stream = torch.cuda.Stream()
|
367
342
|
self.load_stream = torch.cuda.Stream()
|
368
343
|
|
369
|
-
self.write_thread = threading.Thread(
|
370
|
-
target=self.write_thread_func_direct, daemon=True
|
371
|
-
)
|
372
|
-
self.load_thread = threading.Thread(
|
373
|
-
target=self.load_thread_func_layer_by_layer, daemon=True
|
374
|
-
)
|
375
|
-
|
376
|
-
self.write_thread.start()
|
377
|
-
self.load_thread.start()
|
378
|
-
|
379
344
|
if self.enable_storage:
|
380
345
|
self.prefetch_thread = threading.Thread(
|
381
346
|
target=self.prefetch_thread_func, daemon=True
|
@@ -396,49 +361,39 @@ class HiCacheController:
|
|
396
361
|
def _generate_storage_config(
|
397
362
|
self,
|
398
363
|
model_name: Optional[str] = None,
|
399
|
-
storage_backend_extra_config: Optional[
|
364
|
+
storage_backend_extra_config: Optional[dict] = None,
|
400
365
|
):
|
401
366
|
|
402
367
|
if is_dp_attention_enabled():
|
403
368
|
self.tp_rank = get_attention_tp_rank()
|
404
369
|
self.tp_size = get_attention_tp_size()
|
370
|
+
self.dp_rank = get_attention_dp_rank()
|
405
371
|
else:
|
406
372
|
self.tp_rank = get_tensor_model_parallel_rank()
|
407
373
|
self.tp_size = get_tensor_model_parallel_world_size()
|
374
|
+
self.dp_rank = 0
|
408
375
|
|
409
376
|
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
|
410
377
|
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
|
411
378
|
|
412
|
-
# Parse extra config JSON if provided
|
413
|
-
extra_config = None
|
414
|
-
if storage_backend_extra_config:
|
415
|
-
try:
|
416
|
-
import json
|
417
|
-
|
418
|
-
extra_config = json.loads(storage_backend_extra_config)
|
419
|
-
except Exception as e:
|
420
|
-
logger.error(f"Invalid backend extra config JSON: {e}")
|
421
|
-
|
422
379
|
return HiCacheStorageConfig(
|
423
380
|
tp_rank=self.tp_rank,
|
424
381
|
tp_size=self.tp_size,
|
425
382
|
is_mla_model=is_mla_backend,
|
426
383
|
is_page_first_layout=self.mem_pool_host.layout == "page_first",
|
427
384
|
model_name=model_name,
|
428
|
-
extra_config=
|
385
|
+
extra_config=storage_backend_extra_config,
|
429
386
|
)
|
430
387
|
|
431
388
|
def reset(self):
|
432
389
|
self.stop_event.set()
|
433
|
-
self.write_thread.join()
|
434
|
-
self.load_thread.join()
|
435
390
|
|
436
|
-
self.write_queue.
|
437
|
-
self.load_queue.
|
391
|
+
self.write_queue.clear()
|
392
|
+
self.load_queue.clear()
|
438
393
|
self.write_buffer.clear()
|
439
394
|
self.load_buffer.clear()
|
440
|
-
self.ack_write_queue.
|
441
|
-
self.ack_load_queue.
|
395
|
+
self.ack_write_queue.clear()
|
396
|
+
self.ack_load_queue.clear()
|
442
397
|
if self.enable_storage:
|
443
398
|
self.prefetch_thread.join()
|
444
399
|
self.backup_thread.join()
|
@@ -447,15 +402,7 @@ class HiCacheController:
|
|
447
402
|
self.prefetch_revoke_queue.queue.clear()
|
448
403
|
self.ack_backup_queue.queue.clear()
|
449
404
|
|
450
|
-
self.write_thread = threading.Thread(
|
451
|
-
target=self.write_thread_func_direct, daemon=True
|
452
|
-
)
|
453
|
-
self.load_thread = threading.Thread(
|
454
|
-
target=self.load_thread_func_layer_by_layer, daemon=True
|
455
|
-
)
|
456
405
|
self.stop_event.clear()
|
457
|
-
self.write_thread.start()
|
458
|
-
self.load_thread.start()
|
459
406
|
|
460
407
|
if self.enable_storage:
|
461
408
|
self.prefetch_thread = threading.Thread(
|
@@ -471,7 +418,7 @@ class HiCacheController:
|
|
471
418
|
self,
|
472
419
|
device_indices: torch.Tensor,
|
473
420
|
priority: Optional[int] = None,
|
474
|
-
node_id: int =
|
421
|
+
node_id: int = -1,
|
475
422
|
) -> Optional[torch.Tensor]:
|
476
423
|
"""
|
477
424
|
Back up KV caches from device memory to host memory.
|
@@ -479,18 +426,45 @@ class HiCacheController:
|
|
479
426
|
host_indices = self.mem_pool_host.alloc(len(device_indices))
|
480
427
|
if host_indices is None:
|
481
428
|
return None
|
482
|
-
self.
|
483
|
-
torch.cuda.current_stream().synchronize()
|
484
|
-
self.write_queue.put(
|
429
|
+
self.write_queue.append(
|
485
430
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
486
431
|
)
|
432
|
+
self.start_writing()
|
487
433
|
return host_indices
|
488
434
|
|
435
|
+
def start_writing(self) -> None:
|
436
|
+
if len(self.write_queue) == 0:
|
437
|
+
return
|
438
|
+
|
439
|
+
op = CacheOperation.merge_ops(self.write_queue)
|
440
|
+
host_indices, device_indices = self.move_indices(op)
|
441
|
+
self.write_queue.clear()
|
442
|
+
|
443
|
+
start_event = torch.cuda.Event()
|
444
|
+
finish_event = torch.cuda.Event()
|
445
|
+
|
446
|
+
start_event.record()
|
447
|
+
with torch.cuda.stream(self.write_stream):
|
448
|
+
start_event.wait(self.write_stream)
|
449
|
+
self.mem_pool_host.backup_from_device_all_layer(
|
450
|
+
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
451
|
+
)
|
452
|
+
finish_event.record()
|
453
|
+
# NOTE: We must save the host indices and device indices here,
|
454
|
+
# this is because we need to guarantee that these tensors are
|
455
|
+
# still alive when the write stream is executing.
|
456
|
+
if host_indices.is_cuda:
|
457
|
+
host_indices.record_stream(self.write_stream)
|
458
|
+
if device_indices.is_cuda:
|
459
|
+
device_indices.record_stream(self.write_stream)
|
460
|
+
|
461
|
+
self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
|
462
|
+
|
489
463
|
def load(
|
490
464
|
self,
|
491
465
|
host_indices: torch.Tensor,
|
492
466
|
priority: Optional[int] = None,
|
493
|
-
node_id: int =
|
467
|
+
node_id: int = -1,
|
494
468
|
) -> Optional[torch.Tensor]:
|
495
469
|
"""
|
496
470
|
Load KV caches from host memory to device memory.
|
@@ -498,77 +472,42 @@ class HiCacheController:
|
|
498
472
|
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
|
499
473
|
if device_indices is None:
|
500
474
|
return None
|
501
|
-
self.
|
502
|
-
# to ensure the device indices are ready before accessed by another CUDA stream
|
503
|
-
torch.cuda.current_stream().synchronize()
|
504
|
-
self.load_queue.put(
|
475
|
+
self.load_queue.append(
|
505
476
|
CacheOperation(host_indices, device_indices, node_id, priority)
|
506
477
|
)
|
507
478
|
return device_indices
|
508
479
|
|
509
|
-
def move_indices(self,
|
480
|
+
def move_indices(self, op: CacheOperation):
|
481
|
+
host_indices, device_indices = op.host_indices, op.device_indices
|
510
482
|
# move indices to GPU if using kernels, to host if using direct indexing
|
511
483
|
if self.io_backend == "kernel":
|
512
|
-
|
484
|
+
if not host_indices.is_cuda:
|
485
|
+
host_indices = host_indices.to(self.device, non_blocking=True)
|
486
|
+
return host_indices, device_indices
|
513
487
|
elif self.io_backend == "direct":
|
514
|
-
|
515
|
-
|
516
|
-
|
488
|
+
if self.mem_pool_host.layout == "layer_first":
|
489
|
+
device_indices = device_indices.cpu()
|
490
|
+
host_indices, idx = host_indices.sort()
|
491
|
+
return host_indices, device_indices.index_select(0, idx)
|
492
|
+
elif self.mem_pool_host.layout == "page_first_direct":
|
493
|
+
return host_indices, device_indices.cpu()
|
517
494
|
else:
|
518
495
|
raise ValueError(f"Unsupported io backend")
|
519
496
|
|
520
|
-
def
|
521
|
-
|
522
|
-
|
523
|
-
"""
|
524
|
-
torch.cuda.set_stream(self.write_stream)
|
525
|
-
while not self.stop_event.is_set():
|
526
|
-
try:
|
527
|
-
operation = self.write_queue.get(block=True, timeout=1)
|
528
|
-
host_indices, device_indices = self.move_indices(
|
529
|
-
operation.host_indices, operation.device_indices
|
530
|
-
)
|
531
|
-
self.mem_pool_host.backup_from_device_all_layer(
|
532
|
-
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
533
|
-
)
|
534
|
-
self.write_stream.synchronize()
|
535
|
-
self.mem_pool_host.complete_io(operation.host_indices)
|
536
|
-
for node_id in operation.node_ids:
|
537
|
-
if node_id != 0:
|
538
|
-
self.ack_write_queue.put(node_id)
|
539
|
-
except Empty:
|
540
|
-
continue
|
541
|
-
except Exception as e:
|
542
|
-
logger.error(e)
|
497
|
+
def start_loading(self) -> int:
|
498
|
+
if len(self.load_queue) == 0:
|
499
|
+
return -1
|
543
500
|
|
544
|
-
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
self.load_cache_event.wait(timeout=1)
|
551
|
-
if not self.load_cache_event.is_set():
|
552
|
-
continue
|
553
|
-
self.load_cache_event.clear()
|
554
|
-
self.layer_done_counter.update_producer()
|
555
|
-
|
556
|
-
batch_operation = None
|
557
|
-
while self.load_queue.qsize() > 0:
|
558
|
-
op = self.load_queue.get(block=True)
|
559
|
-
if batch_operation is None:
|
560
|
-
batch_operation = op
|
561
|
-
else:
|
562
|
-
batch_operation.merge(op)
|
563
|
-
if batch_operation is None:
|
564
|
-
continue
|
501
|
+
producer_id = self.layer_done_counter.update_producer()
|
502
|
+
op = CacheOperation.merge_ops(self.load_queue)
|
503
|
+
host_indices, device_indices = self.move_indices(op)
|
504
|
+
self.load_queue.clear()
|
505
|
+
producer_event = self.layer_done_counter.events[producer_id]
|
506
|
+
producer_event.start_event.record()
|
565
507
|
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
batch_operation.host_indices, batch_operation.device_indices
|
570
|
-
)
|
571
|
-
for i in range(self.mem_pool_host.layer_num):
|
508
|
+
with torch.cuda.stream(self.load_stream):
|
509
|
+
producer_event.start_event.wait(self.load_stream)
|
510
|
+
for i in range(self.layer_num):
|
572
511
|
self.mem_pool_host.load_to_device_per_layer(
|
573
512
|
self.mem_pool_device,
|
574
513
|
host_indices,
|
@@ -576,37 +515,34 @@ class HiCacheController:
|
|
576
515
|
i,
|
577
516
|
self.io_backend,
|
578
517
|
)
|
579
|
-
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
|
586
|
-
|
587
|
-
|
588
|
-
self
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
return len(device_indices)
|
594
|
-
else:
|
595
|
-
raise ValueError(
|
596
|
-
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
518
|
+
producer_event.complete(i)
|
519
|
+
# NOTE: We must save the host indices and device indices here,
|
520
|
+
# this is because we need to guarantee that these tensors are
|
521
|
+
# still alive when the load stream is executing.
|
522
|
+
if host_indices.is_cuda:
|
523
|
+
host_indices.record_stream(self.load_stream)
|
524
|
+
if device_indices.is_cuda:
|
525
|
+
device_indices.record_stream(self.load_stream)
|
526
|
+
|
527
|
+
self.ack_load_queue.append(
|
528
|
+
HiCacheAck(
|
529
|
+
start_event=producer_event.start_event,
|
530
|
+
finish_event=producer_event.finish_event,
|
531
|
+
node_ids=op.node_ids,
|
597
532
|
)
|
533
|
+
)
|
534
|
+
return producer_id
|
535
|
+
|
536
|
+
def evict_device(self, device_indices: torch.Tensor) -> int:
|
537
|
+
self.mem_pool_device_allocator.free(device_indices)
|
538
|
+
return len(device_indices)
|
598
539
|
|
599
540
|
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
|
600
541
|
if not backup_only:
|
601
542
|
raise ValueError("Other eviction policies are not supported yet.")
|
602
543
|
|
603
|
-
|
604
|
-
|
605
|
-
return len(host_indices)
|
606
|
-
else:
|
607
|
-
raise ValueError(
|
608
|
-
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
|
609
|
-
)
|
544
|
+
self.mem_pool_host.free(host_indices)
|
545
|
+
return len(host_indices)
|
610
546
|
|
611
547
|
def prefetch(
|
612
548
|
self,
|
@@ -625,50 +561,29 @@ class HiCacheController:
|
|
625
561
|
return operation
|
626
562
|
|
627
563
|
def terminate_prefetch(self, operation):
|
628
|
-
operation.
|
564
|
+
operation.mark_terminate()
|
629
565
|
return operation.completed_tokens, operation.hash_value
|
630
566
|
|
631
567
|
def append_host_mem_release(self, host_indices: torch.Tensor):
|
632
|
-
|
633
|
-
|
634
|
-
|
635
|
-
|
636
|
-
|
637
|
-
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
|
638
|
-
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
|
639
|
-
return hit_page_num
|
640
|
-
|
641
|
-
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
|
642
|
-
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
|
643
|
-
hash_values, host_indices
|
644
|
-
)
|
645
|
-
page_data = self.storage_backend.batch_get(hashes, dsts)
|
646
|
-
if page_data:
|
647
|
-
inc = self.page_size * len(hashes) // factor
|
648
|
-
operation.increment(inc)
|
649
|
-
else:
|
650
|
-
logger.warning(
|
651
|
-
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
|
652
|
-
)
|
568
|
+
if host_indices.numel() == 0:
|
569
|
+
return
|
570
|
+
pages = host_indices.split(self.mem_pool_host.page_size)
|
571
|
+
for page in pages:
|
572
|
+
self.host_mem_release_queue.put(page)
|
653
573
|
|
654
|
-
def
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
)
|
665
|
-
if get_result != len(hash_values):
|
666
|
-
logger.warning(
|
667
|
-
f"Prefetch operation {operation.request_id} failed or partially failed."
|
668
|
-
)
|
669
|
-
if get_result != 0:
|
670
|
-
operation.increment(get_result * self.page_size)
|
574
|
+
def _page_get_zero_copy(self, operation, hash_values, host_indices):
|
575
|
+
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
|
576
|
+
inc = 0
|
577
|
+
for i in range(len(hash_values)):
|
578
|
+
if not results[i]:
|
579
|
+
logger.warning(
|
580
|
+
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
|
581
|
+
)
|
582
|
+
break
|
583
|
+
inc += self.page_size
|
584
|
+
operation.increment(inc)
|
671
585
|
|
586
|
+
# todo: deprecate
|
672
587
|
def _generic_page_get(self, operation, hash_values, host_indices):
|
673
588
|
dummy_page_dst = [
|
674
589
|
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
|
@@ -706,6 +621,7 @@ class HiCacheController:
|
|
706
621
|
operation.completed_tokens
|
707
622
|
!= prev_completed_tokens + len(batch_hashes) * self.page_size
|
708
623
|
):
|
624
|
+
operation.mark_terminate()
|
709
625
|
break # Some operations fail or operation terminated by controller
|
710
626
|
# release pre-allocated memory
|
711
627
|
self.append_host_mem_release(
|
@@ -757,7 +673,7 @@ class HiCacheController:
|
|
757
673
|
batch_tokens[i : i + self.page_size], last_hash
|
758
674
|
)
|
759
675
|
batch_hashes.append(last_hash)
|
760
|
-
hit_page_num = self.
|
676
|
+
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
|
761
677
|
hash_value.extend(batch_hashes[:hit_page_num])
|
762
678
|
storage_query_count += hit_page_num * self.page_size
|
763
679
|
if hit_page_num < len(batch_hashes):
|
@@ -826,34 +742,16 @@ class HiCacheController:
|
|
826
742
|
self.backup_queue.put(operation)
|
827
743
|
return operation.id
|
828
744
|
|
829
|
-
#
|
745
|
+
# todo: deprecate
|
830
746
|
def _generic_page_set(self, hash_values, host_indices) -> bool:
|
831
747
|
data = [
|
832
|
-
self.mem_pool_host.
|
748
|
+
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
|
833
749
|
for i in range(len(hash_values))
|
834
750
|
]
|
835
751
|
return self.storage_backend.batch_set(hash_values, data)
|
836
752
|
|
837
|
-
|
838
|
-
|
839
|
-
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
|
840
|
-
hash_values,
|
841
|
-
host_indices,
|
842
|
-
self.storage_config.tp_rank,
|
843
|
-
)
|
844
|
-
success = self.storage_backend.batch_set(
|
845
|
-
key_strs,
|
846
|
-
target_locations=buffer_ptrs,
|
847
|
-
target_sizes=buffer_sizes,
|
848
|
-
)
|
849
|
-
return success
|
850
|
-
|
851
|
-
# zero copy
|
852
|
-
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
|
853
|
-
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
|
854
|
-
hash_values, host_indices
|
855
|
-
)
|
856
|
-
return self.storage_backend.batch_set(hashes, dsts)
|
753
|
+
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
|
754
|
+
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
|
857
755
|
|
858
756
|
# Backup batch by batch
|
859
757
|
def _page_backup(self, operation):
|
@@ -885,7 +783,7 @@ class HiCacheController:
|
|
885
783
|
|
886
784
|
if not self.backup_skip:
|
887
785
|
self._page_backup(operation)
|
888
|
-
self.ack_backup_queue.put(operation
|
786
|
+
self.ack_backup_queue.put(operation)
|
889
787
|
|
890
788
|
except Empty:
|
891
789
|
continue
|