sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -1,37 +1,30 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import asyncio
|
|
4
3
|
import dataclasses
|
|
5
4
|
import logging
|
|
6
|
-
import
|
|
7
|
-
import socket
|
|
5
|
+
import os
|
|
8
6
|
import struct
|
|
9
7
|
import threading
|
|
8
|
+
import time
|
|
10
9
|
import uuid
|
|
11
10
|
from collections import defaultdict
|
|
12
|
-
from
|
|
13
|
-
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
|
|
11
|
+
from typing import Dict, List, Optional, Set
|
|
14
12
|
|
|
15
13
|
import numpy as np
|
|
16
14
|
import numpy.typing as npt
|
|
17
15
|
import requests
|
|
18
|
-
import zmq
|
|
19
|
-
from aiohttp import web
|
|
20
16
|
|
|
21
|
-
from sglang.srt.disaggregation.base.conn import
|
|
17
|
+
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
|
22
18
|
from sglang.srt.disaggregation.common.conn import (
|
|
23
19
|
CommonKVBootstrapServer,
|
|
24
20
|
CommonKVManager,
|
|
25
21
|
CommonKVReceiver,
|
|
22
|
+
CommonKVSender,
|
|
26
23
|
)
|
|
27
24
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
|
28
25
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
29
26
|
from sglang.srt.server_args import ServerArgs
|
|
30
|
-
from sglang.srt.utils import
|
|
31
|
-
format_tcp_address,
|
|
32
|
-
get_local_ip_auto,
|
|
33
|
-
is_valid_ipv6_address,
|
|
34
|
-
)
|
|
27
|
+
from sglang.srt.utils import get_int_env_var
|
|
35
28
|
|
|
36
29
|
logger = logging.getLogger(__name__)
|
|
37
30
|
|
|
@@ -113,8 +106,14 @@ class TransferStatus:
|
|
|
113
106
|
def is_done(self):
|
|
114
107
|
if self.num_kvs_expected is None:
|
|
115
108
|
return False
|
|
109
|
+
# Check for failure state
|
|
110
|
+
if self.num_kvs_expected == -1:
|
|
111
|
+
return True # Failed transfers are considered "done"
|
|
116
112
|
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
|
|
117
113
|
|
|
114
|
+
def is_failed(self):
|
|
115
|
+
return self.num_kvs_expected == -1
|
|
116
|
+
|
|
118
117
|
|
|
119
118
|
class NixlKVManager(CommonKVManager):
|
|
120
119
|
def __init__(
|
|
@@ -134,26 +133,133 @@ class NixlKVManager(CommonKVManager):
|
|
|
134
133
|
"to run SGLang with NixlTransferEngine."
|
|
135
134
|
) from e
|
|
136
135
|
self.agent = nixl_agent(str(uuid.uuid4()))
|
|
137
|
-
self.local_ip = get_local_ip_auto()
|
|
138
|
-
self.server_socket = zmq.Context().socket(zmq.PULL)
|
|
139
|
-
if is_valid_ipv6_address(self.local_ip):
|
|
140
|
-
self.server_socket.setsockopt(zmq.IPV6, 1)
|
|
141
136
|
self.register_buffer_to_engine()
|
|
142
137
|
|
|
143
138
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
|
144
|
-
self.request_status: Dict[int, KVPoll] = {}
|
|
145
|
-
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
|
146
|
-
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
|
147
139
|
self._start_bootstrap_thread()
|
|
148
140
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
|
149
141
|
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
|
150
142
|
TransferStatus
|
|
151
143
|
)
|
|
144
|
+
self.heartbeat_failures = {}
|
|
145
|
+
self.session_pool = defaultdict(requests.Session)
|
|
146
|
+
self.session_pool_lock = threading.Lock()
|
|
147
|
+
self.addr_to_rooms_tracker = defaultdict(set)
|
|
148
|
+
self.connection_lock = threading.Lock()
|
|
149
|
+
|
|
150
|
+
# Heartbeat interval should be at least 2 seconds
|
|
151
|
+
self.heartbeat_interval = max(
|
|
152
|
+
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
|
|
153
|
+
)
|
|
154
|
+
# Heartbeat failure should be at least 1
|
|
155
|
+
self.max_failures = max(
|
|
156
|
+
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
|
157
|
+
)
|
|
158
|
+
self._start_heartbeat_checker_thread()
|
|
152
159
|
else:
|
|
153
160
|
raise ValueError(
|
|
154
161
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
|
155
162
|
)
|
|
156
163
|
|
|
164
|
+
def _start_heartbeat_checker_thread(self):
|
|
165
|
+
"""
|
|
166
|
+
Start the heartbeat checker thread for Decode worker.
|
|
167
|
+
TODO (smor): unite nixl heartbeat checker with mooncake's.
|
|
168
|
+
"""
|
|
169
|
+
|
|
170
|
+
def heartbeat_checker():
|
|
171
|
+
while True:
|
|
172
|
+
time.sleep(self.heartbeat_interval)
|
|
173
|
+
with self.connection_lock:
|
|
174
|
+
addresses = list(self.prefill_dp_size_table.keys())
|
|
175
|
+
|
|
176
|
+
for bootstrap_addr in addresses:
|
|
177
|
+
session = None
|
|
178
|
+
try:
|
|
179
|
+
with self.session_pool_lock:
|
|
180
|
+
session = self.session_pool[bootstrap_addr]
|
|
181
|
+
response = session.get(
|
|
182
|
+
f"http://{bootstrap_addr}/health",
|
|
183
|
+
timeout=(2, 3),
|
|
184
|
+
headers={"Connection": "keep-alive"},
|
|
185
|
+
)
|
|
186
|
+
if response.status_code == 200:
|
|
187
|
+
self.heartbeat_failures[bootstrap_addr] = 0
|
|
188
|
+
|
|
189
|
+
current_rooms = self.addr_to_rooms_tracker[
|
|
190
|
+
bootstrap_addr
|
|
191
|
+
].copy()
|
|
192
|
+
|
|
193
|
+
for bootstrap_room in current_rooms:
|
|
194
|
+
# Remove successful transfers from the tracker
|
|
195
|
+
if bootstrap_room not in self.transfer_statuses:
|
|
196
|
+
self.addr_to_rooms_tracker[bootstrap_addr].discard(
|
|
197
|
+
bootstrap_room
|
|
198
|
+
)
|
|
199
|
+
else:
|
|
200
|
+
logger.info(
|
|
201
|
+
f"Attempting to reconnect to {bootstrap_addr}..."
|
|
202
|
+
)
|
|
203
|
+
self.heartbeat_failures[bootstrap_addr] = (
|
|
204
|
+
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
|
|
205
|
+
)
|
|
206
|
+
with self.session_pool_lock:
|
|
207
|
+
if bootstrap_addr in self.session_pool:
|
|
208
|
+
del self.session_pool[bootstrap_addr]
|
|
209
|
+
except Exception:
|
|
210
|
+
logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
|
|
211
|
+
self.heartbeat_failures[bootstrap_addr] = (
|
|
212
|
+
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
|
|
213
|
+
)
|
|
214
|
+
|
|
215
|
+
if (
|
|
216
|
+
self.heartbeat_failures.get(bootstrap_addr, 0)
|
|
217
|
+
>= self.max_failures
|
|
218
|
+
):
|
|
219
|
+
self._handle_node_failure(bootstrap_addr)
|
|
220
|
+
with self.session_pool_lock:
|
|
221
|
+
if bootstrap_addr in self.session_pool:
|
|
222
|
+
del self.session_pool[bootstrap_addr]
|
|
223
|
+
|
|
224
|
+
threading.Thread(target=heartbeat_checker, daemon=True).start()
|
|
225
|
+
|
|
226
|
+
def _handle_node_failure(self, failed_bootstrap_addr):
|
|
227
|
+
"""Handle failure of a prefill node."""
|
|
228
|
+
with self.connection_lock:
|
|
229
|
+
keys_to_remove = [
|
|
230
|
+
k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
|
|
231
|
+
]
|
|
232
|
+
for k in keys_to_remove:
|
|
233
|
+
del self.connection_pool[k]
|
|
234
|
+
if failed_bootstrap_addr in self.prefill_tp_size_table:
|
|
235
|
+
del self.prefill_tp_size_table[failed_bootstrap_addr]
|
|
236
|
+
if failed_bootstrap_addr in self.prefill_dp_size_table:
|
|
237
|
+
del self.prefill_dp_size_table[failed_bootstrap_addr]
|
|
238
|
+
if failed_bootstrap_addr in self.prefill_pp_size_table:
|
|
239
|
+
del self.prefill_pp_size_table[failed_bootstrap_addr]
|
|
240
|
+
|
|
241
|
+
possible_affected_rooms = self.addr_to_rooms_tracker.get(
|
|
242
|
+
failed_bootstrap_addr, []
|
|
243
|
+
)
|
|
244
|
+
if failed_bootstrap_addr in self.addr_to_rooms_tracker:
|
|
245
|
+
del self.addr_to_rooms_tracker[failed_bootstrap_addr]
|
|
246
|
+
|
|
247
|
+
# Mark all pending transfers associated with the failed node as failed
|
|
248
|
+
affected_rooms = []
|
|
249
|
+
for room in possible_affected_rooms:
|
|
250
|
+
if (
|
|
251
|
+
room in self.transfer_statuses
|
|
252
|
+
and not self.transfer_statuses[room].is_done()
|
|
253
|
+
):
|
|
254
|
+
# Mark the transfer as failed by setting a special state
|
|
255
|
+
self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
|
|
256
|
+
affected_rooms.append(room)
|
|
257
|
+
|
|
258
|
+
logger.error(
|
|
259
|
+
f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
|
|
260
|
+
f"{len(affected_rooms)} transfers affected"
|
|
261
|
+
)
|
|
262
|
+
|
|
157
263
|
def check_status(self, bootstrap_room: int):
|
|
158
264
|
return self.request_status[bootstrap_room]
|
|
159
265
|
|
|
@@ -166,6 +272,9 @@ class NixlKVManager(CommonKVManager):
|
|
|
166
272
|
self.request_status[bootstrap_room], status
|
|
167
273
|
)
|
|
168
274
|
|
|
275
|
+
def record_failure(self, bootstrap_room: int, failure_reason: str):
|
|
276
|
+
pass
|
|
277
|
+
|
|
169
278
|
def register_buffer_to_engine(self):
|
|
170
279
|
kv_addrs = []
|
|
171
280
|
for kv_data_ptr, kv_data_len in zip(
|
|
@@ -210,14 +319,44 @@ class NixlKVManager(CommonKVManager):
|
|
|
210
319
|
|
|
211
320
|
logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
|
|
212
321
|
# Make descs
|
|
213
|
-
|
|
322
|
+
if self.is_mla_backend:
|
|
323
|
+
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
|
324
|
+
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
|
325
|
+
)
|
|
326
|
+
kv_item_len = self.kv_args.kv_item_lens[0]
|
|
327
|
+
layers_params = [
|
|
328
|
+
(
|
|
329
|
+
src_kv_ptrs[layer_id],
|
|
330
|
+
dst_kv_ptrs[layer_id],
|
|
331
|
+
kv_item_len,
|
|
332
|
+
)
|
|
333
|
+
for layer_id in range(layers_current_pp_stage)
|
|
334
|
+
]
|
|
335
|
+
else:
|
|
336
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
|
337
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
kv_item_len = self.kv_args.kv_item_lens[0]
|
|
341
|
+
layers_params = [
|
|
342
|
+
(
|
|
343
|
+
src_k_ptrs[layer_id],
|
|
344
|
+
dst_k_ptrs[layer_id],
|
|
345
|
+
kv_item_len,
|
|
346
|
+
)
|
|
347
|
+
for layer_id in range(layers_current_pp_stage)
|
|
348
|
+
] + [
|
|
349
|
+
(
|
|
350
|
+
src_v_ptrs[layer_id],
|
|
351
|
+
dst_v_ptrs[layer_id],
|
|
352
|
+
kv_item_len,
|
|
353
|
+
)
|
|
354
|
+
for layer_id in range(layers_current_pp_stage)
|
|
355
|
+
]
|
|
356
|
+
|
|
214
357
|
src_addrs = []
|
|
215
358
|
dst_addrs = []
|
|
216
|
-
for
|
|
217
|
-
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
|
|
218
|
-
dst_ptr = dst_kv_ptrs[layer_id]
|
|
219
|
-
item_len = self.kv_args.kv_item_lens[layer_id]
|
|
220
|
-
|
|
359
|
+
for src_ptr, dst_ptr, item_len in layers_params:
|
|
221
360
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
|
222
361
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
|
223
362
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
|
@@ -288,6 +427,9 @@ class NixlKVManager(CommonKVManager):
|
|
|
288
427
|
num_heads_to_send = dst_heads_per_rank
|
|
289
428
|
dst_head_start_offset = 0
|
|
290
429
|
|
|
430
|
+
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
|
431
|
+
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
|
432
|
+
)
|
|
291
433
|
# Create transfer descriptors
|
|
292
434
|
src_addrs = []
|
|
293
435
|
dst_addrs = []
|
|
@@ -295,12 +437,6 @@ class NixlKVManager(CommonKVManager):
|
|
|
295
437
|
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
|
296
438
|
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
|
297
439
|
|
|
298
|
-
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
|
299
|
-
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
|
300
|
-
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
|
301
|
-
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
|
302
|
-
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
|
303
|
-
|
|
304
440
|
# Calculate precise byte offset and length for the sub-slice within the token
|
|
305
441
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
|
306
442
|
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
|
@@ -311,13 +447,13 @@ class NixlKVManager(CommonKVManager):
|
|
|
311
447
|
src_k_ptrs[layer_id],
|
|
312
448
|
dst_k_ptrs[layer_id],
|
|
313
449
|
)
|
|
314
|
-
for layer_id in range(
|
|
450
|
+
for layer_id in range(layers_current_pp_stage)
|
|
315
451
|
] + [
|
|
316
452
|
(
|
|
317
453
|
src_v_ptrs[layer_id],
|
|
318
454
|
dst_v_ptrs[layer_id],
|
|
319
455
|
)
|
|
320
|
-
for layer_id in range(
|
|
456
|
+
for layer_id in range(layers_current_pp_stage)
|
|
321
457
|
]
|
|
322
458
|
|
|
323
459
|
src_addrs = []
|
|
@@ -387,14 +523,19 @@ class NixlKVManager(CommonKVManager):
|
|
|
387
523
|
dst_aux_index: int,
|
|
388
524
|
notif: str,
|
|
389
525
|
):
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
526
|
+
src_addrs = []
|
|
527
|
+
dst_addrs = []
|
|
528
|
+
|
|
529
|
+
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
|
530
|
+
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
|
531
|
+
|
|
532
|
+
for i, _ in enumerate(dst_aux_ptrs):
|
|
533
|
+
length = prefill_aux_item_lens[i]
|
|
534
|
+
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
|
535
|
+
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
|
536
|
+
src_addrs.append((src_addr, length, 0))
|
|
537
|
+
dst_addrs.append((dst_addr, length, 0))
|
|
538
|
+
|
|
398
539
|
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
|
399
540
|
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
|
400
541
|
# Transfer data
|
|
@@ -438,7 +579,7 @@ class NixlKVManager(CommonKVManager):
|
|
|
438
579
|
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
|
439
580
|
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
|
|
440
581
|
|
|
441
|
-
if decode_tp_size == self.
|
|
582
|
+
if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
|
|
442
583
|
kv_xfer_handle = self.send_kvcache(
|
|
443
584
|
req.agent_name,
|
|
444
585
|
kv_indices,
|
|
@@ -455,7 +596,7 @@ class NixlKVManager(CommonKVManager):
|
|
|
455
596
|
chunked_dst_kv_indice,
|
|
456
597
|
self.decode_kv_args_table[req.agent_name].gpu_id,
|
|
457
598
|
notif,
|
|
458
|
-
prefill_tp_size=self.
|
|
599
|
+
prefill_tp_size=self.attn_tp_size,
|
|
459
600
|
decode_tp_size=decode_tp_size,
|
|
460
601
|
decode_tp_rank=self.decode_kv_args_table[
|
|
461
602
|
req.agent_name
|
|
@@ -467,7 +608,7 @@ class NixlKVManager(CommonKVManager):
|
|
|
467
608
|
|
|
468
609
|
handles.append(kv_xfer_handle)
|
|
469
610
|
# Only the last chunk we need to send the aux data.
|
|
470
|
-
if is_last:
|
|
611
|
+
if is_last and self.pp_group.is_last_rank:
|
|
471
612
|
assert aux_index is not None
|
|
472
613
|
aux_xfer_handle = self.send_aux(
|
|
473
614
|
req.agent_name,
|
|
@@ -505,9 +646,6 @@ class NixlKVManager(CommonKVManager):
|
|
|
505
646
|
return False
|
|
506
647
|
return self.transfer_statuses[room].is_done()
|
|
507
648
|
|
|
508
|
-
def _bind_server_socket(self):
|
|
509
|
-
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
|
510
|
-
|
|
511
649
|
def _start_bootstrap_thread(self):
|
|
512
650
|
self._bind_server_socket()
|
|
513
651
|
|
|
@@ -548,7 +686,7 @@ class NixlKVManager(CommonKVManager):
|
|
|
548
686
|
threading.Thread(target=bootstrap_thread).start()
|
|
549
687
|
|
|
550
688
|
|
|
551
|
-
class NixlKVSender(
|
|
689
|
+
class NixlKVSender(CommonKVSender):
|
|
552
690
|
|
|
553
691
|
def __init__(
|
|
554
692
|
self,
|
|
@@ -558,24 +696,15 @@ class NixlKVSender(BaseKVSender):
|
|
|
558
696
|
dest_tp_ranks: List[int],
|
|
559
697
|
pp_rank: int,
|
|
560
698
|
):
|
|
561
|
-
|
|
562
|
-
self.bootstrap_room = bootstrap_room
|
|
563
|
-
self.aux_index = None
|
|
564
|
-
self.bootstrap_server_url = bootstrap_addr
|
|
699
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
|
|
565
700
|
self.xfer_handles = []
|
|
566
701
|
self.has_sent = False
|
|
567
702
|
self.chunk_id = 0
|
|
568
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
|
569
|
-
# inner state
|
|
570
|
-
self.curr_idx = 0
|
|
571
|
-
|
|
572
|
-
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
|
573
|
-
self.num_kv_indices = num_kv_indices
|
|
574
|
-
self.aux_index = aux_index
|
|
575
703
|
|
|
576
704
|
def send(
|
|
577
705
|
self,
|
|
578
706
|
kv_indices: npt.NDArray[np.int32],
|
|
707
|
+
state_indices: Optional[List[int]] = None,
|
|
579
708
|
):
|
|
580
709
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
|
581
710
|
self.curr_idx += len(kv_indices)
|
|
@@ -621,7 +750,25 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
|
621
750
|
self.conclude_state = None
|
|
622
751
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
|
623
752
|
|
|
624
|
-
|
|
753
|
+
# Track this room with its bootstrap address for heartbeat monitoring
|
|
754
|
+
if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
|
|
755
|
+
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
|
|
756
|
+
self.bootstrap_room
|
|
757
|
+
)
|
|
758
|
+
|
|
759
|
+
def init(
|
|
760
|
+
self,
|
|
761
|
+
kv_indices: npt.NDArray[np.int32],
|
|
762
|
+
aux_index: Optional[int] = None,
|
|
763
|
+
state_indices: Optional[List[int]] = None,
|
|
764
|
+
):
|
|
765
|
+
if self.bootstrap_infos is None:
|
|
766
|
+
logger.error(
|
|
767
|
+
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
|
|
768
|
+
)
|
|
769
|
+
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
|
|
770
|
+
return
|
|
771
|
+
|
|
625
772
|
for bootstrap_info in self.bootstrap_infos:
|
|
626
773
|
logger.debug(
|
|
627
774
|
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
|
|
@@ -655,9 +802,16 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
|
655
802
|
|
|
656
803
|
self.kv_mgr.update_transfer_status()
|
|
657
804
|
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
|
658
|
-
|
|
805
|
+
# Check if the transfer failed
|
|
806
|
+
if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
|
|
807
|
+
self.conclude_state = KVPoll.Failed
|
|
808
|
+
logger.error(
|
|
809
|
+
f"Transfer for room {self.bootstrap_room} failed due to node failure"
|
|
810
|
+
)
|
|
811
|
+
else:
|
|
812
|
+
self.conclude_state = KVPoll.Success
|
|
659
813
|
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
|
|
660
|
-
return
|
|
814
|
+
return self.conclude_state # type: ignore
|
|
661
815
|
return KVPoll.WaitingForInput # type: ignore
|
|
662
816
|
|
|
663
817
|
def _register_kv_args(self):
|