sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/disaggregation/nixl/conn.py
+++ b/sglang/srt/disaggregation/nixl/conn.py
@@ -1,37 +1,30 @@
 from __future__ import annotations
 
-import asyncio
 import dataclasses
 import logging
-import
-import socket
+import os
 import struct
 import threading
+import time
 import uuid
 from collections import defaultdict
-from
-from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
+from typing import Dict, List, Optional, Set
 
 import numpy as np
 import numpy.typing as npt
 import requests
-import zmq
-from aiohttp import web
 
-from sglang.srt.disaggregation.base.conn import
+from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
 from sglang.srt.disaggregation.common.conn import (
     CommonKVBootstrapServer,
     CommonKVManager,
     CommonKVReceiver,
+    CommonKVSender,
 )
 from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import (
-    format_tcp_address,
-    get_local_ip_auto,
-    is_valid_ipv6_address,
-)
+from sglang.srt.utils import get_int_env_var
 
 logger = logging.getLogger(__name__)
 
@@ -78,6 +71,9 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     gpu_id: int
+    decode_tp_size: int
+    decode_tp_rank: int
+    dst_kv_item_len: int
 
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +86,9 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
             gpu_id=int(msg[7].decode("ascii")),
+            decode_tp_size=int(msg[8].decode("ascii")),
+            decode_tp_rank=int(msg[9].decode("ascii")),
+            dst_kv_item_len=int(msg[10].decode("ascii")),
         )
 
 
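The decode worker now advertises three extra values during registration; they travel as additional ASCII-encoded frames on the existing ZMQ message (appended in `_register_kv_args`, last hunk below) and are parsed back by `KVArgsRegisterInfo.from_zmq`. A minimal sketch of that round-trip, with purely illustrative values:

    # Hypothetical frames 8-10 as the decode side would append them:
    # decode_tp_size, decode_tp_rank (its engine_rank), dst_kv_item_len.
    extra_frames = [str(v).encode("ascii") for v in (4, 1, 65536)]

    # Prefill side, mirroring KVArgsRegisterInfo.from_zmq:
    decode_tp_size, decode_tp_rank, dst_kv_item_len = (
        int(frame.decode("ascii")) for frame in extra_frames
    )
    assert (decode_tp_size, decode_tp_rank, dst_kv_item_len) == (4, 1, 65536)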
@@ -107,8 +106,14 @@ class TransferStatus:
     def is_done(self):
         if self.num_kvs_expected is None:
             return False
+        # Check for failure state
+        if self.num_kvs_expected == -1:
+            return True  # Failed transfers are considered "done"
         return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
 
+    def is_failed(self):
+        return self.num_kvs_expected == -1
+
 
 class NixlKVManager(CommonKVManager):
     def __init__(
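`num_kvs_expected == -1` now doubles as a failure sentinel: a failed transfer reports itself as both done and failed, so pollers drain it instead of waiting forever. A self-contained sketch of that contract (field names mirror the hunk above; the real class also records which KV chunks and aux data arrived):

    from dataclasses import dataclass, field
    from typing import Optional, Set


    @dataclass
    class StatusSketch:
        num_kvs_expected: Optional[int] = None
        received_kvs: Set[int] = field(default_factory=set)
        received_aux: bool = False

        def is_done(self) -> bool:
            if self.num_kvs_expected is None:
                return False
            if self.num_kvs_expected == -1:  # failure sentinel
                return True
            return self.num_kvs_expected == len(self.received_kvs) and self.received_aux

        def is_failed(self) -> bool:
            return self.num_kvs_expected == -1


    status = StatusSketch()
    assert not status.is_done()          # nothing known yet
    status.num_kvs_expected = -1         # what _handle_node_failure sets (later hunk)
    assert status.is_done() and status.is_failed()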
@@ -128,26 +133,133 @@ class NixlKVManager(CommonKVManager):
                 "to run SGLang with NixlTransferEngine."
             ) from e
         self.agent = nixl_agent(str(uuid.uuid4()))
-        self.local_ip = get_local_ip_auto()
-        self.server_socket = zmq.Context().socket(zmq.PULL)
-        if is_valid_ipv6_address(self.local_ip):
-            self.server_socket.setsockopt(zmq.IPV6, 1)
         self.register_buffer_to_engine()
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status: Dict[int, KVPoll] = {}
-            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
-            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
            self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
+            self.heartbeat_failures = {}
+            self.session_pool = defaultdict(requests.Session)
+            self.session_pool_lock = threading.Lock()
+            self.addr_to_rooms_tracker = defaultdict(set)
+            self.connection_lock = threading.Lock()
+
+            # Heartbeat interval should be at least 2 seconds
+            self.heartbeat_interval = max(
+                float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
+            )
+            # Heartbeat failure should be at least 1
+            self.max_failures = max(
+                get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
+            )
+            self._start_heartbeat_checker_thread()
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
             )
 
+    def _start_heartbeat_checker_thread(self):
+        """
+        Start the heartbeat checker thread for Decode worker.
+        TODO (smor): unite nixl heartbeat checker with mooncake's.
+        """
+
+        def heartbeat_checker():
+            while True:
+                time.sleep(self.heartbeat_interval)
+                with self.connection_lock:
+                    addresses = list(self.prefill_dp_size_table.keys())
+
+                for bootstrap_addr in addresses:
+                    session = None
+                    try:
+                        with self.session_pool_lock:
+                            session = self.session_pool[bootstrap_addr]
+                        response = session.get(
+                            f"http://{bootstrap_addr}/health",
+                            timeout=(2, 3),
+                            headers={"Connection": "keep-alive"},
+                        )
+                        if response.status_code == 200:
+                            self.heartbeat_failures[bootstrap_addr] = 0
+
+                            current_rooms = self.addr_to_rooms_tracker[
+                                bootstrap_addr
+                            ].copy()
+
+                            for bootstrap_room in current_rooms:
+                                # Remove successful transfers from the tracker
+                                if bootstrap_room not in self.transfer_statuses:
+                                    self.addr_to_rooms_tracker[bootstrap_addr].discard(
+                                        bootstrap_room
+                                    )
+                        else:
+                            logger.info(
+                                f"Attempting to reconnect to {bootstrap_addr}..."
+                            )
+                            self.heartbeat_failures[bootstrap_addr] = (
+                                self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                            )
+                            with self.session_pool_lock:
+                                if bootstrap_addr in self.session_pool:
+                                    del self.session_pool[bootstrap_addr]
+                    except Exception:
+                        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
+                        self.heartbeat_failures[bootstrap_addr] = (
+                            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                        )
+
+                    if (
+                        self.heartbeat_failures.get(bootstrap_addr, 0)
+                        >= self.max_failures
+                    ):
+                        self._handle_node_failure(bootstrap_addr)
+                        with self.session_pool_lock:
+                            if bootstrap_addr in self.session_pool:
+                                del self.session_pool[bootstrap_addr]
+
+        threading.Thread(target=heartbeat_checker, daemon=True).start()
+
+    def _handle_node_failure(self, failed_bootstrap_addr):
+        """Handle failure of a prefill node."""
+        with self.connection_lock:
+            keys_to_remove = [
+                k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
+            ]
+            for k in keys_to_remove:
+                del self.connection_pool[k]
+            if failed_bootstrap_addr in self.prefill_tp_size_table:
+                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_dp_size_table:
+                del self.prefill_dp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_pp_size_table:
+                del self.prefill_pp_size_table[failed_bootstrap_addr]
+
+            possible_affected_rooms = self.addr_to_rooms_tracker.get(
+                failed_bootstrap_addr, []
+            )
+            if failed_bootstrap_addr in self.addr_to_rooms_tracker:
+                del self.addr_to_rooms_tracker[failed_bootstrap_addr]
+
+        # Mark all pending transfers associated with the failed node as failed
+        affected_rooms = []
+        for room in possible_affected_rooms:
+            if (
+                room in self.transfer_statuses
+                and not self.transfer_statuses[room].is_done()
+            ):
+                # Mark the transfer as failed by setting a special state
+                self.transfer_statuses[room].num_kvs_expected = -1  # Indicates failure
+                affected_rooms.append(room)
+
+        logger.error(
+            f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
+            f"{len(affected_rooms)} transfers affected"
+        )
+
     def check_status(self, bootstrap_room: int):
         return self.request_status[bootstrap_room]
 
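The decode-side heartbeat is configured through two environment variables and clamped to a floor of a 2-second interval and at least one tolerated failure. A small sketch of the effective values, emulating sglang's `get_int_env_var` with plain `os.getenv` (variable names come from the hunk above; the numbers are just an example):

    import os

    # Example: an aggressive 1-second interval and 0 allowed failures get clamped upward.
    os.environ["SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL"] = "1"
    os.environ["SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE"] = "0"

    heartbeat_interval = max(
        float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
    )
    max_failures = max(int(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1)

    assert heartbeat_interval == 2.0 and max_failures == 1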
@@ -160,13 +272,16 @@ class NixlKVManager(CommonKVManager):
                 self.request_status[bootstrap_room], status
             )
 
+    def record_failure(self, bootstrap_room: int, failure_reason: str):
+        pass
+
     def register_buffer_to_engine(self):
         kv_addrs = []
         for kv_data_ptr, kv_data_len in zip(
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM"
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +290,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM"
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
@@ -204,14 +319,44 @@ class NixlKVManager(CommonKVManager):
 
         logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
         # Make descs
-
+        if self.is_mla_backend:
+            src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
+                self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+            )
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_kv_ptrs[layer_id],
+                    dst_kv_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_current_pp_stage)
+            ]
+        else:
+            src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
+                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+            )
+
+            kv_item_len = self.kv_args.kv_item_lens[0]
+            layers_params = [
+                (
+                    src_k_ptrs[layer_id],
+                    dst_k_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_current_pp_stage)
+            ] + [
+                (
+                    src_v_ptrs[layer_id],
+                    dst_v_ptrs[layer_id],
+                    kv_item_len,
+                )
+                for layer_id in range(layers_current_pp_stage)
+            ]
+
         src_addrs = []
         dst_addrs = []
-        for
-            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
-            dst_ptr = dst_kv_ptrs[layer_id]
-            item_len = self.kv_args.kv_item_lens[layer_id]
-
+        for src_ptr, dst_ptr, item_len in layers_params:
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
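For the MHA path, `layers_params` pairs each layer's K pointer and then each layer's V pointer with the matching destination pointers, so the copy loop no longer indexes `kv_data_ptrs` by layer directly. A rough sketch of that pairing, assuming a flat pointer list that stores all K buffers followed by all V buffers; the real split (including pipeline-parallel stage offsets) is done by `get_mha_kv_ptrs_with_pp`:

    # Sketch only: the "K buffers first, then V buffers" layout is an assumption here.
    def build_layers_params(src_ptrs, dst_ptrs, kv_item_len):
        num_layers = len(src_ptrs) // 2
        src_k, src_v = src_ptrs[:num_layers], src_ptrs[num_layers:]
        dst_k, dst_v = dst_ptrs[:num_layers], dst_ptrs[num_layers:]
        return [(src_k[i], dst_k[i], kv_item_len) for i in range(num_layers)] + [
            (src_v[i], dst_v[i], kv_item_len) for i in range(num_layers)
        ]


    params = build_layers_params([100, 200, 300, 400], [500, 600, 700, 800], kv_item_len=64)
    assert params == [(100, 500, 64), (200, 600, 64), (300, 700, 64), (400, 800, 64)]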
@@ -222,8 +367,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -239,6 +384,137 @@ class NixlKVManager(CommonKVManager):
             raise Exception("KVSender failed to post transfer")
         return xfer_handle
 
+    def send_kvcache_slice(
+        self,
+        peer_name: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        dst_gpu_id: int,
+        notif: str,
+        prefill_tp_size: int,
+        decode_tp_size: int,
+        decode_tp_rank: int,
+        dst_kv_item_len: int,
+    ):
+        # Get configuration from kv_args
+        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
+        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
+        num_kv_heads = self.kv_args.kv_head_num
+
+        # Calculate head distribution
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
+
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        page_size = self.kv_args.page_size
+
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
+
+        # Determine which heads to send
+        if prefill_tp_size > decode_tp_size:
+            # Multiple prefill ranks to one decode rank
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
+        else:
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
+
+        src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
+            self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+        )
+        # Create transfer descriptors
+        src_addrs = []
+        dst_addrs = []
+
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        src_dst_ptr_pairs = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+            )
+            for layer_id in range(layers_current_pp_stage)
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+            )
+            for layer_id in range(layers_current_pp_stage)
+        ]
+
+        src_addrs = []
+        dst_addrs = []
+
+        # Calculate strides for a single token slot
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        for src_ptr, dst_ptr in src_dst_ptr_pairs:
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
+
+                    src_addrs.append(
+                        (
+                            src_slice_addr,
+                            heads_bytes_per_token_to_send,
+                            self.kv_args.gpu_id,
+                        )
+                    )
+                    dst_addrs.append(
+                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
+                    )
+
+        # Use NIXL agent for transfer
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
+
+        xfer_handle = self.agent.initialize_xfer(
+            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
+        )
+        if not xfer_handle:
+            raise Exception("Failed to create sliced KV transfer")
+
+        state = self.agent.transfer(xfer_handle)
+        if state == "ERR":
+            raise Exception("Failed to post sliced KV transfer")
+
+        return xfer_handle
+
     def send_aux(
         self,
         peer_name: str,
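The head-slicing arithmetic in `send_kvcache_slice` is easier to follow with concrete numbers. A worked example with illustrative sizes only (8 KV heads per prefill rank, prefill TP 2, decode TP 4, 16-token pages, 512 bytes per head per token):

    prefill_tp_size, decode_tp_size = 2, 4
    num_kv_heads = 8                     # kv_args.kv_head_num on the prefill rank
    page_size = 16
    bytes_per_head = 512

    src_kv_item_len = num_kv_heads * bytes_per_head * page_size            # 65536 per page
    dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size  # 4
    dst_kv_item_len = dst_heads_per_rank * bytes_per_head * page_size      # 32768 per page

    bytes_per_token_on_prefill = src_kv_item_len // page_size              # 4096
    bytes_per_head_slice_to_send = dst_kv_item_len // page_size // dst_heads_per_rank
    assert bytes_per_head_slice_to_send == 512

    # prefill TP < decode TP: each decode rank gets a different 4-head slice.
    # For the decode rank with dst_tp_rank_in_group == 1:
    dst_tp_rank_in_group = 1
    src_head_start_offset = (dst_tp_rank_in_group * dst_heads_per_rank) % num_kv_heads  # 4
    num_heads_to_send = dst_heads_per_rank                                              # 4
    heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send    # 2048
    assert (src_head_start_offset, heads_bytes_per_token_to_send) == (4, 2048)

So in this setup decode rank 1 receives heads 4-7 from the prefill rank, written at head offset 0 of its own, narrower pages; each token slot contributes one 2048-byte descriptor pair per layer and per K/V buffer.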
@@ -247,16 +523,21 @@ class NixlKVManager(CommonKVManager):
         dst_aux_index: int,
         notif: str,
     ):
-
-
-
-
-
-
-
-
-
-
+        src_addrs = []
+        dst_addrs = []
+
+        prefill_aux_ptrs = self.kv_args.aux_data_ptrs
+        prefill_aux_item_lens = self.kv_args.aux_item_lens
+
+        for i, _ in enumerate(dst_aux_ptrs):
+            length = prefill_aux_item_lens[i]
+            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
+            src_addrs.append((src_addr, length, 0))
+            dst_addrs.append((dst_addr, length, 0))
+
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -296,17 +577,38 @@ class NixlKVManager(CommonKVManager):
             assert req.agent_name in self.decode_kv_args_table
 
             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
-
-
-
-            self.
-
-
-
-
+            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
+
+            if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
+                kv_xfer_handle = self.send_kvcache(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                )
+            else:
+                kv_xfer_handle = self.send_kvcache_slice(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                    prefill_tp_size=self.attn_tp_size,
+                    decode_tp_size=decode_tp_size,
+                    decode_tp_rank=self.decode_kv_args_table[
+                        req.agent_name
+                    ].decode_tp_rank,
+                    dst_kv_item_len=self.decode_kv_args_table[
+                        req.agent_name
+                    ].dst_kv_item_len,
+                )
+
             handles.append(kv_xfer_handle)
             # Only the last chunk we need to send the aux data.
-            if is_last:
+            if is_last and self.pp_group.is_last_rank:
                 assert aux_index is not None
                 aux_xfer_handle = self.send_aux(
                     req.agent_name,
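The dispatch added here is small but central: MLA caches (a single latent KV per token rather than per-head K/V, so there is nothing to re-shard) and matched attention TP sizes take the existing whole-page `send_kvcache` path, while an MHA cache going to a different decode TP size falls back to the per-head-slice path. A restatement of the rule as a tiny predicate (the function name is illustrative, not sglang API):

    def uses_whole_page_transfer(
        is_mla_backend: bool, prefill_attn_tp_size: int, decode_tp_size: int
    ) -> bool:
        # Mirrors the condition in the hunk above: slice only when an MHA cache
        # must be re-sharded across a different decode TP size.
        return is_mla_backend or decode_tp_size == prefill_attn_tp_size


    assert uses_whole_page_transfer(True, 2, 4)        # MLA: never sliced
    assert uses_whole_page_transfer(False, 4, 4)       # matching TP: whole pages
    assert not uses_whole_page_transfer(False, 2, 4)   # MHA with TP mismatch: sliced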
@@ -344,9 +646,6 @@ class NixlKVManager(CommonKVManager):
             return False
         return self.transfer_statuses[room].is_done()
 
-    def _bind_server_socket(self):
-        self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
-
     def _start_bootstrap_thread(self):
         self._bind_server_socket()
 
@@ -387,7 +686,7 @@ class NixlKVManager(CommonKVManager):
         threading.Thread(target=bootstrap_thread).start()
 
 
-class NixlKVSender(
+class NixlKVSender(CommonKVSender):
 
     def __init__(
         self,
@@ -397,20 +696,10 @@ class NixlKVSender(BaseKVSender):
         dest_tp_ranks: List[int],
         pp_rank: int,
     ):
-
-        self.bootstrap_room = bootstrap_room
-        self.aux_index = None
-        self.bootstrap_server_url = bootstrap_addr
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
         self.xfer_handles = []
         self.has_sent = False
         self.chunk_id = 0
-        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
-        # inner state
-        self.curr_idx = 0
-
-    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
-        self.num_kv_indices = num_kv_indices
-        self.aux_index = aux_index
 
     def send(
         self,
@@ -454,11 +743,17 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room,
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
+
+        # Track this room with its bootstrap address for heartbeat monitoring
+        if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
+            self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
+                self.bootstrap_room
+            )
 
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
@@ -494,9 +789,16 @@ class NixlKVReceiver(CommonKVReceiver):
 
         self.kv_mgr.update_transfer_status()
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
-
+            # Check if the transfer failed
+            if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
+                self.conclude_state = KVPoll.Failed
+                logger.error(
+                    f"Transfer for room {self.bootstrap_room} failed due to node failure"
+                )
+            else:
+                self.conclude_state = KVPoll.Success
             del self.kv_mgr.transfer_statuses[self.bootstrap_room]
-            return
+            return self.conclude_state  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore
 
     def _register_kv_args(self):
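With the sentinel wired through, `poll()` can now conclude `KVPoll.Failed` (set when the heartbeat checker declares the prefill node dead) as well as `KVPoll.Success`; both are terminal. A sketch of how a caller might consume the three outcomes (the loop itself is illustrative, not sglang API):

    import time

    from sglang.srt.disaggregation.base.conn import KVPoll


    def wait_for_kv(receiver, poll_interval_s: float = 0.01) -> bool:
        # Illustrative polling loop: Success and Failed are both terminal;
        # WaitingForInput means the transfer is still in flight.
        while True:
            status = receiver.poll()
            if status == KVPoll.Success:
                return True
            if status == KVPoll.Failed:
                return False  # e.g. the prefill instance was lost mid-transfer
            time.sleep(poll_interval_s)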
@@ -521,6 +823,9 @@ class NixlKVReceiver(CommonKVReceiver):
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
                 str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
+                str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
+                str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
             ]
         )
 