sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
The hunks below are from sglang/srt/disaggregation/nixl/conn.py (listed above with +326 -53); a few removed lines appear truncated because the diff viewer elided their content.

```diff
@@ -1,37 +1,30 @@
 from __future__ import annotations
 
-import asyncio
 import dataclasses
 import logging
-import
-import socket
+import os
 import struct
 import threading
+import time
 import uuid
 from collections import defaultdict
-from
-from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
+from typing import Dict, List, Optional, Set
 
 import numpy as np
 import numpy.typing as npt
 import requests
-import zmq
-from aiohttp import web
 
-from sglang.srt.disaggregation.base.conn import
+from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
 from sglang.srt.disaggregation.common.conn import (
     CommonKVBootstrapServer,
     CommonKVManager,
     CommonKVReceiver,
+    CommonKVSender,
 )
 from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import
-    format_tcp_address,
-    get_local_ip_auto,
-    is_valid_ipv6_address,
-)
+from sglang.srt.utils import get_int_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -78,6 +71,9 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     gpu_id: int
+    decode_tp_size: int
+    decode_tp_rank: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +86,9 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
             gpu_id=int(msg[7].decode("ascii")),
+            decode_tp_size=int(msg[8].decode("ascii")),
+            decode_tp_rank=int(msg[9].decode("ascii")),
+            dst_kv_item_len=int(msg[10].decode("ascii")),
         )
 
 
@@ -107,8 +106,14 @@ class TransferStatus:
     def is_done(self):
         if self.num_kvs_expected is None:
             return False
+        # Check for failure state
+        if self.num_kvs_expected == -1:
+            return True  # Failed transfers are considered "done"
         return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
 
+    def is_failed(self):
+        return self.num_kvs_expected == -1
+
 
 class NixlKVManager(CommonKVManager):
     def __init__(
```
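The failure handling added to TransferStatus treats num_kvs_expected == -1 as a poison value: a failed transfer still counts as "done" so pollers stop waiting, and is_failed() distinguishes it from success. A minimal sketch of that behavior, using a stand-in class rather than the real dataclass:

```python
# Stand-in mirroring only the TransferStatus fields used above (illustrative, not the real class).
class _Status:
    def __init__(self):
        self.num_kvs_expected = None  # unknown until the sender announces the expected count
        self.received_kvs = set()
        self.received_aux = False

    def is_done(self):
        if self.num_kvs_expected is None:
            return False
        if self.num_kvs_expected == -1:  # failure sentinel, set when a prefill node is declared dead
            return True
        return self.num_kvs_expected == len(self.received_kvs) and self.received_aux

    def is_failed(self):
        return self.num_kvs_expected == -1


s = _Status()
s.num_kvs_expected = -1  # what the node-failure path (next hunk) does for affected rooms
assert s.is_done() and s.is_failed()
```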
```diff
@@ -128,26 +133,133 @@ class NixlKVManager(CommonKVManager):
                 "to run SGLang with NixlTransferEngine."
             ) from e
         self.agent = nixl_agent(str(uuid.uuid4()))
-        self.local_ip = get_local_ip_auto()
-        self.server_socket = zmq.Context().socket(zmq.PULL)
-        if is_valid_ipv6_address(self.local_ip):
-            self.server_socket.setsockopt(zmq.IPV6, 1)
         self.register_buffer_to_engine()
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status: Dict[int, KVPoll] = {}
-            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
-            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
+            self.heartbeat_failures = {}
+            self.session_pool = defaultdict(requests.Session)
+            self.session_pool_lock = threading.Lock()
+            self.addr_to_rooms_tracker = defaultdict(set)
+            self.connection_lock = threading.Lock()
+
+            # Heartbeat interval should be at least 2 seconds
+            self.heartbeat_interval = max(
+                float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
+            )
+            # Heartbeat failure should be at least 1
+            self.max_failures = max(
+                get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
+            )
+            self._start_heartbeat_checker_thread()
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
             )
 
+    def _start_heartbeat_checker_thread(self):
+        """
+        Start the heartbeat checker thread for Decode worker.
+        TODO (smor): unite nixl heartbeat checker with mooncake's.
+        """
+
+        def heartbeat_checker():
+            while True:
+                time.sleep(self.heartbeat_interval)
+                with self.connection_lock:
+                    addresses = list(self.prefill_dp_size_table.keys())
+
+                for bootstrap_addr in addresses:
+                    session = None
+                    try:
+                        with self.session_pool_lock:
+                            session = self.session_pool[bootstrap_addr]
+                        response = session.get(
+                            f"http://{bootstrap_addr}/health",
+                            timeout=(2, 3),
+                            headers={"Connection": "keep-alive"},
+                        )
+                        if response.status_code == 200:
+                            self.heartbeat_failures[bootstrap_addr] = 0
+
+                            current_rooms = self.addr_to_rooms_tracker[
+                                bootstrap_addr
+                            ].copy()
+
+                            for bootstrap_room in current_rooms:
+                                # Remove successful transfers from the tracker
+                                if bootstrap_room not in self.transfer_statuses:
+                                    self.addr_to_rooms_tracker[bootstrap_addr].discard(
+                                        bootstrap_room
+                                    )
+                        else:
+                            logger.info(
+                                f"Attempting to reconnect to {bootstrap_addr}..."
+                            )
+                            self.heartbeat_failures[bootstrap_addr] = (
+                                self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                            )
+                            with self.session_pool_lock:
+                                if bootstrap_addr in self.session_pool:
+                                    del self.session_pool[bootstrap_addr]
+                    except Exception:
+                        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
+                        self.heartbeat_failures[bootstrap_addr] = (
+                            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                        )
+
+                    if (
+                        self.heartbeat_failures.get(bootstrap_addr, 0)
+                        >= self.max_failures
+                    ):
+                        self._handle_node_failure(bootstrap_addr)
+                        with self.session_pool_lock:
+                            if bootstrap_addr in self.session_pool:
+                                del self.session_pool[bootstrap_addr]
+
+        threading.Thread(target=heartbeat_checker, daemon=True).start()
+
+    def _handle_node_failure(self, failed_bootstrap_addr):
+        """Handle failure of a prefill node."""
+        with self.connection_lock:
+            keys_to_remove = [
+                k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
+            ]
+            for k in keys_to_remove:
+                del self.connection_pool[k]
+            if failed_bootstrap_addr in self.prefill_tp_size_table:
+                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_dp_size_table:
+                del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]
+
+            possible_affected_rooms = self.addr_to_rooms_tracker.get(
+                failed_bootstrap_addr, []
+            )
+            if failed_bootstrap_addr in self.addr_to_rooms_tracker:
+                del self.addr_to_rooms_tracker[failed_bootstrap_addr]
+
+        # Mark all pending transfers associated with the failed node as failed
+        affected_rooms = []
+        for room in possible_affected_rooms:
+            if (
+                room in self.transfer_statuses
+                and not self.transfer_statuses[room].is_done()
+            ):
+                # Mark the transfer as failed by setting a special state
+                self.transfer_statuses[room].num_kvs_expected = -1  # Indicates failure
+                affected_rooms.append(room)
+
+        logger.error(
+            f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
+            f"{len(affected_rooms)} transfers affected"
+        )
+
     def check_status(self, bootstrap_room: int):
         return self.request_status[bootstrap_room]
 
```
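For reference, the new decode-side heartbeat loop is tuned through the two environment variables visible above, with the interval floored at 2 seconds and the failure budget at 1. A minimal sketch of the resulting values under assumed settings (it uses plain int()/os.getenv() in place of SGLang's get_int_env_var helper):

```python
import os

# Example settings; both variable names come from the hunk above.
os.environ["SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL"] = "1.0"
os.environ["SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE"] = "3"

# Same clamping as NixlKVManager.__init__ in decode mode.
heartbeat_interval = max(
    float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
)
max_failures = max(int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1)

# 1.0 is clamped up to 2.0; a prefill node is declared dead after 3 consecutive
# failed GETs to http://<bootstrap_addr>/health.
print(heartbeat_interval, max_failures)  # 2.0 3
```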
```diff
@@ -160,13 +272,16 @@ class NixlKVManager(CommonKVManager):
             self.request_status[bootstrap_room], status
         )
 
+    def record_failure(self, bootstrap_room: int, failure_reason: str):
+        pass
+
     def register_buffer_to_engine(self):
         kv_addrs = []
         for kv_data_ptr, kv_data_len in zip(
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM"
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +290,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM"
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
@@ -222,8 +337,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -239,6 +354,140 @@ class NixlKVManager(CommonKVManager):
             raise Exception("KVSender failed to post transfer")
         return xfer_handle
 
+    def send_kvcache_slice(
+        self,
+        peer_name: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        dst_gpu_id: int,
+        notif: str,
+        prefill_tp_size: int,
+        decode_tp_size: int,
+        decode_tp_rank: int,
+        dst_kv_item_len: int,
+    ):
+        # Get configuration from kv_args
+        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
+        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
+        num_kv_heads = self.kv_args.kv_head_num
+
+        # Calculate head distribution
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
+
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        page_size = self.kv_args.page_size
+
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
+
+        # Determine which heads to send
+        if prefill_tp_size > decode_tp_size:
+            # Multiple prefill ranks to one decode rank
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
+        else:
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
+
+        # Create transfer descriptors
+        src_addrs = []
+        dst_addrs = []
+
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
+        dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        src_dst_ptr_pairs = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_k_ptrs))
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_v_ptrs))
+        ]
+
+        src_addrs = []
+        dst_addrs = []
+
+        # Calculate strides for a single token slot
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        for src_ptr, dst_ptr in src_dst_ptr_pairs:
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
+
+                    src_addrs.append(
+                        (
+                            src_slice_addr,
+                            heads_bytes_per_token_to_send,
+                            self.kv_args.gpu_id,
+                        )
+                    )
+                    dst_addrs.append(
+                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
+                    )
+
+        # Use NIXL agent for transfer
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
+
+        xfer_handle = self.agent.initialize_xfer(
+            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
+        )
+        if not xfer_handle:
+            raise Exception("Failed to create sliced KV transfer")
+
+        state = self.agent.transfer(xfer_handle)
+        if state == "ERR":
+            raise Exception("Failed to post sliced KV transfer")
+
+        return xfer_handle
+
     def send_aux(
         self,
         peer_name: str,
```
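To make the offset arithmetic in send_kvcache_slice easier to follow, here is a toy calculation under an assumed asymmetric TP layout; the numbers are invented for illustration and only the formulas are taken from the hunk above:

```python
# Assumed toy setup: 16 KV heads total, prefill TP=2 (8 heads per prefill rank),
# decode TP=4 (4 heads per decode rank), page_size=1, 4096-byte KV item per page on decode.
num_kv_heads = 8                      # kv_head_num seen by one prefill rank
prefill_tp_size, decode_tp_size = 2, 4
page_size = 1
dst_kv_item_len = 4096                # bytes per page slot on the decode side
engine_rank, decode_tp_rank = 1, 3    # example ranks

src_heads_per_rank = num_kv_heads
dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size      # 4
bytes_per_head_slice = dst_kv_item_len // page_size // dst_heads_per_rank  # 1024

local_tp_rank_in_group = engine_rank % prefill_tp_size  # 1
dst_tp_rank_in_group = decode_tp_rank % decode_tp_size  # 3

# prefill_tp_size <= decode_tp_size: one prefill rank feeds several decode ranks,
# so it sends only the head sub-range owned by decode rank 3.
src_head_start = (dst_tp_rank_in_group * dst_heads_per_rank) % src_heads_per_rank  # 12 % 8 = 4
num_heads_to_send = dst_heads_per_rank                                              # 4
dst_head_start = 0

src_head_slice_offset = src_head_start * bytes_per_head_slice     # 4096-byte offset into each source token
heads_bytes_per_token = num_heads_to_send * bytes_per_head_slice  # 4096 bytes sent per token slot
print(src_head_slice_offset, heads_bytes_per_token)  # 4096 4096
```

In other words, prefill rank 1 (holding global heads 8-15 as local heads 0-7) ships its local heads 4-7 to decode rank 3, which owns global heads 12-15.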
```diff
@@ -255,8 +504,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -296,14 +545,35 @@ class NixlKVManager(CommonKVManager):
             assert req.agent_name in self.decode_kv_args_table
 
             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
-
-
-
-            self.
-
-
-
-
+            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
+
+            if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
+                kv_xfer_handle = self.send_kvcache(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                )
+            else:
+                kv_xfer_handle = self.send_kvcache_slice(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                    prefill_tp_size=self.attn_tp_size,
+                    decode_tp_size=decode_tp_size,
+                    decode_tp_rank=self.decode_kv_args_table[
+                        req.agent_name
+                    ].decode_tp_rank,
+                    dst_kv_item_len=self.decode_kv_args_table[
+                        req.agent_name
+                    ].dst_kv_item_len,
+                )
+
             handles.append(kv_xfer_handle)
             # Only the last chunk we need to send the aux data.
             if is_last:
@@ -344,9 +614,6 @@ class NixlKVManager(CommonKVManager):
             return False
         return self.transfer_statuses[room].is_done()
 
-    def _bind_server_socket(self):
-        self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
-
     def _start_bootstrap_thread(self):
         self._bind_server_socket()
 
@@ -387,7 +654,7 @@ class NixlKVManager(CommonKVManager):
         threading.Thread(target=bootstrap_thread).start()
 
 
-class NixlKVSender(
+class NixlKVSender(CommonKVSender):
 
     def __init__(
         self,
@@ -397,20 +664,10 @@ class NixlKVSender(BaseKVSender):
         dest_tp_ranks: List[int],
         pp_rank: int,
     ):
-
-        self.bootstrap_room = bootstrap_room
-        self.aux_index = None
-        self.bootstrap_server_url = bootstrap_addr
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
         self.xfer_handles = []
         self.has_sent = False
         self.chunk_id = 0
-        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
-        # inner state
-        self.curr_idx = 0
-
-    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
-        self.num_kv_indices = num_kv_indices
-        self.aux_index = aux_index
 
     def send(
         self,
@@ -454,11 +711,17 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room,
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
+
+        # Track this room with its bootstrap address for heartbeat monitoring
+        if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
+            self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
+                self.bootstrap_room
+            )
 
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
@@ -494,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
 
         self.kv_mgr.update_transfer_status()
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
-
+            # Check if the transfer failed
+            if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
+                self.conclude_state = KVPoll.Failed
+                logger.error(
+                    f"Transfer for room {self.bootstrap_room} failed due to node failure"
+                )
+            else:
+                self.conclude_state = KVPoll.Success
             del self.kv_mgr.transfer_statuses[self.bootstrap_room]
-            return
+            return self.conclude_state  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore
 
     def _register_kv_args(self):
@@ -521,6 +791,9 @@ class NixlKVReceiver(CommonKVReceiver):
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
                 str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
+                str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
+                str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
             ]
         )
 
```
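The last hunk pairs with KVArgsRegisterInfo.from_zmq earlier in the diff: the decode receiver appends three extra ascii-encoded frames to its registration message, which the prefill side reads back as msg[8]-msg[10]. A minimal sketch of that pack/unpack symmetry, with the surrounding ZMQ framing simplified away and example values assumed:

```python
# Decode side: extra frames appended to the registration message (positions 8-10).
decode_tp_size, engine_rank, kv_item_len = 4, 3, 4096  # example values
extra_frames = [
    str(decode_tp_size).encode("ascii"),  # msg[8]  -> decode_tp_size
    str(engine_rank).encode("ascii"),     # msg[9]  -> decode_tp_rank
    str(kv_item_len).encode("ascii"),     # msg[10] -> dst_kv_item_len
]

# Prefill side: the corresponding parsing, as in KVArgsRegisterInfo.from_zmq.
decode_tp_size_rx = int(extra_frames[0].decode("ascii"))
decode_tp_rank_rx = int(extra_frames[1].decode("ascii"))
dst_kv_item_len_rx = int(extra_frames[2].decode("ascii"))
assert (decode_tp_size_rx, decode_tp_rank_rx, dst_kv_item_len_rx) == (4, 3, 4096)
```

These three fields are exactly what send_kvcache_slice needs to pick the right head sub-range and byte offsets when prefill and decode run with different tensor-parallel sizes.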