sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
@@ -21,9 +21,10 @@ from __future__ import annotations
|
|
21
21
|
|
22
22
|
import logging
|
23
23
|
import threading
|
24
|
+
import time
|
24
25
|
from collections import deque
|
25
26
|
from http import HTTPStatus
|
26
|
-
from typing import TYPE_CHECKING, List, Optional
|
27
|
+
from typing import TYPE_CHECKING, List, Optional, Type
|
27
28
|
|
28
29
|
import torch
|
29
30
|
|
@@ -42,7 +43,12 @@ from sglang.srt.disaggregation.utils import (
|
|
42
43
|
poll_and_all_reduce,
|
43
44
|
prepare_abort,
|
44
45
|
)
|
45
|
-
from sglang.srt.managers.schedule_batch import
|
46
|
+
from sglang.srt.managers.schedule_batch import (
|
47
|
+
FINISH_LENGTH,
|
48
|
+
Req,
|
49
|
+
RequestStage,
|
50
|
+
ScheduleBatch,
|
51
|
+
)
|
46
52
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
47
53
|
from sglang.srt.utils import (
|
48
54
|
DynamicGradMode,
|
@@ -140,8 +146,10 @@ class PrefillBootstrapQueue:
|
|
140
146
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
141
147
|
kv_args.gpu_id = self.scheduler.gpu_id
|
142
148
|
|
143
|
-
kv_manager_class = get_kv_class(
|
144
|
-
|
149
|
+
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
150
|
+
self.transfer_backend, KVClassType.MANAGER
|
151
|
+
)
|
152
|
+
kv_manager: BaseKVManager = kv_manager_class(
|
145
153
|
kv_args,
|
146
154
|
DisaggregationMode.PREFILL,
|
147
155
|
self.scheduler.server_args,
|
@@ -168,6 +176,7 @@ class PrefillBootstrapQueue:
|
|
168
176
|
pp_rank=self.pp_rank,
|
169
177
|
)
|
170
178
|
self._process_req(req)
|
179
|
+
req.add_latency(RequestStage.PREFILL_PREPARE)
|
171
180
|
self.queue.append(req)
|
172
181
|
|
173
182
|
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
@@ -254,8 +263,11 @@ class PrefillBootstrapQueue:
|
|
254
263
|
|
255
264
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
256
265
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
266
|
+
|
257
267
|
bootstrapped_reqs.append(req)
|
258
268
|
indices_to_remove.add(i)
|
269
|
+
req.time_stats.wait_queue_entry_time = time.perf_counter()
|
270
|
+
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
|
259
271
|
|
260
272
|
self.queue = [
|
261
273
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
@@ -309,6 +321,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
309
321
|
self.result_queue = deque()
|
310
322
|
|
311
323
|
while True:
|
324
|
+
self.launch_last_batch_sample_if_needed()
|
325
|
+
|
312
326
|
recv_reqs = self.recv_requests()
|
313
327
|
self.process_input_requests(recv_reqs)
|
314
328
|
self.waiting_queue.extend(
|
@@ -324,21 +338,8 @@ class SchedulerDisaggregationPrefillMixin:
|
|
324
338
|
result = self.run_batch(batch)
|
325
339
|
self.result_queue.append((batch.copy(), result))
|
326
340
|
|
327
|
-
if self.last_batch is None:
|
328
|
-
# Create a dummy first batch to start the pipeline for overlap schedule.
|
329
|
-
# It is now used for triggering the sampling_info_done event.
|
330
|
-
tmp_batch = ScheduleBatch(
|
331
|
-
reqs=None,
|
332
|
-
forward_mode=ForwardMode.DUMMY_FIRST,
|
333
|
-
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
334
|
-
)
|
335
|
-
self.set_next_batch_sampling_info_done(tmp_batch)
|
336
|
-
|
337
341
|
if self.last_batch:
|
338
342
|
tmp_batch, tmp_result = self.result_queue.popleft()
|
339
|
-
tmp_batch.next_batch_sampling_info = (
|
340
|
-
self.tp_worker.cur_sampling_info if batch else None
|
341
|
-
)
|
342
343
|
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
343
344
|
|
344
345
|
if len(self.disagg_prefill_inflight_queue) > 0:
|
@@ -356,7 +357,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
356
357
|
self: Scheduler,
|
357
358
|
batch: ScheduleBatch,
|
358
359
|
result: GenerationBatchResult,
|
359
|
-
launch_done: Optional[threading.Event] = None,
|
360
360
|
) -> None:
|
361
361
|
"""
|
362
362
|
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
@@ -367,41 +367,40 @@ class SchedulerDisaggregationPrefillMixin:
|
|
367
367
|
next_token_ids,
|
368
368
|
extend_input_len_per_req,
|
369
369
|
extend_logprob_start_len_per_req,
|
370
|
+
copy_done,
|
370
371
|
) = (
|
371
372
|
result.logits_output,
|
372
373
|
result.next_token_ids,
|
373
374
|
result.extend_input_len_per_req,
|
374
375
|
result.extend_logprob_start_len_per_req,
|
376
|
+
result.copy_done,
|
375
377
|
)
|
376
378
|
|
379
|
+
if copy_done is not None:
|
380
|
+
copy_done.synchronize()
|
381
|
+
|
377
382
|
logprob_pt = 0
|
378
383
|
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
|
379
|
-
|
380
|
-
|
381
|
-
logits_output
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
logits_output.next_token_logprobs.tolist()
|
390
|
-
)
|
391
|
-
if logits_output.input_token_logprobs is not None:
|
392
|
-
logits_output.input_token_logprobs = tuple(
|
393
|
-
logits_output.input_token_logprobs.tolist()
|
394
|
-
)
|
384
|
+
next_token_ids = result.next_token_ids.tolist()
|
385
|
+
if batch.return_logprob:
|
386
|
+
if logits_output.next_token_logprobs is not None:
|
387
|
+
logits_output.next_token_logprobs = (
|
388
|
+
logits_output.next_token_logprobs.tolist()
|
389
|
+
)
|
390
|
+
if logits_output.input_token_logprobs is not None:
|
391
|
+
logits_output.input_token_logprobs = tuple(
|
392
|
+
logits_output.input_token_logprobs.tolist()
|
393
|
+
)
|
395
394
|
|
396
395
|
hidden_state_offset = 0
|
397
396
|
for i, (req, next_token_id) in enumerate(
|
398
397
|
zip(batch.reqs, next_token_ids, strict=True)
|
399
398
|
):
|
400
|
-
req: Req
|
401
399
|
if req.is_chunked <= 0:
|
402
400
|
# There is no output_ids for prefill
|
403
401
|
req.output_ids.append(next_token_id)
|
404
402
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
403
|
+
req.add_latency(RequestStage.PREFILL_FORWARD)
|
405
404
|
self.disagg_prefill_inflight_queue.append(req)
|
406
405
|
if (
|
407
406
|
logits_output is not None
|
@@ -410,9 +409,16 @@ class SchedulerDisaggregationPrefillMixin:
|
|
410
409
|
last_hidden_index = (
|
411
410
|
hidden_state_offset + extend_input_len_per_req[i] - 1
|
412
411
|
)
|
413
|
-
req.
|
414
|
-
|
415
|
-
)
|
412
|
+
req.output_topk_p = batch.spec_info.topk_p[i]
|
413
|
+
req.output_topk_index = batch.spec_info.topk_index[i]
|
414
|
+
if self.spec_algorithm.is_eagle3():
|
415
|
+
req.hidden_states_tensor = (
|
416
|
+
batch.spec_info.hidden_states[i].cpu().clone()
|
417
|
+
)
|
418
|
+
else:
|
419
|
+
req.hidden_states_tensor = (
|
420
|
+
logits_output.hidden_states[last_hidden_index].cpu().clone()
|
421
|
+
)
|
416
422
|
hidden_state_offset += extend_input_len_per_req[i]
|
417
423
|
else:
|
418
424
|
req.hidden_states_tensor = None
|
@@ -432,6 +438,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
432
438
|
)
|
433
439
|
logprob_pt += num_input_logprobs
|
434
440
|
self.send_kv_chunk(req, last_chunk=True)
|
441
|
+
req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
|
435
442
|
|
436
443
|
if req.grammar is not None:
|
437
444
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
@@ -471,8 +478,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
471
478
|
if self.enable_overlap:
|
472
479
|
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
|
473
480
|
|
474
|
-
# We need to remove the sync in the following function for overlap schedule.
|
475
|
-
self.set_next_batch_sampling_info_done(batch)
|
476
481
|
self.maybe_send_health_check_signal()
|
477
482
|
|
478
483
|
def process_disagg_prefill_inflight_queue(
|
@@ -529,6 +534,9 @@ class SchedulerDisaggregationPrefillMixin:
|
|
529
534
|
else:
|
530
535
|
assert False, f"Unexpected polling state {poll=}"
|
531
536
|
|
537
|
+
for req in done_reqs:
|
538
|
+
req.time_stats.completion_time = time.perf_counter()
|
539
|
+
|
532
540
|
# Stream requests which have finished transfer
|
533
541
|
self.stream_output(
|
534
542
|
done_reqs,
|
@@ -537,6 +545,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
537
545
|
)
|
538
546
|
for req in done_reqs:
|
539
547
|
req: Req
|
548
|
+
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
|
540
549
|
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
541
550
|
req.metadata_buffer_index = -1
|
542
551
|
|
@@ -665,7 +674,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
665
674
|
self.running_mbs = [
|
666
675
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
667
676
|
]
|
668
|
-
bids = [None] * self.pp_size
|
669
677
|
pp_outputs: Optional[PPProxyTensors] = None
|
670
678
|
|
671
679
|
# Either success or failed
|
@@ -737,10 +745,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
737
745
|
# send the outputs to the next step
|
738
746
|
if self.pp_group.is_last_rank:
|
739
747
|
if self.cur_batch:
|
740
|
-
next_token_ids
|
741
|
-
result.next_token_ids,
|
742
|
-
result.bid,
|
743
|
-
)
|
748
|
+
next_token_ids = result.next_token_ids
|
744
749
|
pp_outputs = PPProxyTensors(
|
745
750
|
{
|
746
751
|
"next_token_ids": next_token_ids,
|
@@ -777,7 +782,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
777
782
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
778
783
|
extend_input_len_per_req=None,
|
779
784
|
extend_logprob_start_len_per_req=None,
|
780
|
-
bid=bids[next_mb_id],
|
781
785
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
782
786
|
)
|
783
787
|
self.process_batch_result_disagg_prefill(
|
@@ -794,8 +798,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
794
798
|
|
795
799
|
# carry the outputs to the next stage
|
796
800
|
if not self.pp_group.is_last_rank:
|
797
|
-
if self.cur_batch:
|
798
|
-
bids[mb_id] = result.bid
|
799
801
|
if pp_outputs:
|
800
802
|
# send the outputs from the last round to let the next stage worker run post processing
|
801
803
|
self.pp_group.send_tensor_dict(
|
@@ -814,8 +816,10 @@ class SchedulerDisaggregationPrefillMixin:
|
|
814
816
|
|
815
817
|
# send out proxy tensors to the next stage
|
816
818
|
if self.cur_batch:
|
819
|
+
# FIXME(lsyin): remove this assert
|
820
|
+
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
817
821
|
self.pp_group.send_tensor_dict(
|
818
|
-
result.pp_hidden_states_proxy_tensors,
|
822
|
+
result.pp_hidden_states_proxy_tensors.tensors,
|
819
823
|
all_gather_group=self.attn_tp_group,
|
820
824
|
)
|
821
825
|
|
@@ -1,21 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import dataclasses
|
4
3
|
import os
|
5
4
|
import random
|
6
|
-
import threading
|
7
|
-
import warnings
|
8
5
|
from collections import deque
|
9
6
|
from contextlib import nullcontext
|
10
7
|
from enum import Enum
|
11
|
-
from typing import TYPE_CHECKING,
|
8
|
+
from typing import TYPE_CHECKING, Optional, Type
|
12
9
|
|
13
10
|
import numpy as np
|
14
|
-
import requests
|
15
11
|
import torch
|
16
12
|
import torch.distributed as dist
|
17
13
|
|
18
|
-
from sglang.srt.utils import
|
14
|
+
from sglang.srt.utils import is_npu
|
19
15
|
|
20
16
|
if TYPE_CHECKING:
|
21
17
|
from sglang.srt.managers.schedule_batch import Req
|
@@ -89,7 +85,7 @@ class MetadataBuffers:
|
|
89
85
|
self,
|
90
86
|
size: int,
|
91
87
|
hidden_size: int,
|
92
|
-
|
88
|
+
hidden_states_dtype: torch.dtype,
|
93
89
|
max_top_logprobs_num: int = 128,
|
94
90
|
custom_mem_pool: torch.cuda.MemPool = None,
|
95
91
|
):
|
@@ -111,7 +107,9 @@ class MetadataBuffers:
|
|
111
107
|
# We transfer the metadata of first output token to decode
|
112
108
|
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
113
109
|
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
114
|
-
|
110
|
+
self.cached_tokens = torch.zeros(
|
111
|
+
(size, 16), dtype=torch.int32, device=device
|
112
|
+
)
|
115
113
|
self.output_token_logprobs_val = torch.zeros(
|
116
114
|
(size, 16), dtype=torch.float32, device=device
|
117
115
|
)
|
@@ -124,33 +122,49 @@ class MetadataBuffers:
|
|
124
122
|
self.output_top_logprobs_idx = torch.zeros(
|
125
123
|
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
126
124
|
)
|
125
|
+
# For PD + spec decode
|
126
|
+
self.output_topk_p = torch.zeros(
|
127
|
+
(size, 16), dtype=torch.float32, device=device
|
128
|
+
)
|
129
|
+
self.output_topk_index = torch.zeros(
|
130
|
+
(size, 16), dtype=torch.int64, device=device
|
131
|
+
)
|
127
132
|
self.output_hidden_states = torch.zeros(
|
128
|
-
(size, hidden_size), dtype=
|
133
|
+
(size, hidden_size), dtype=hidden_states_dtype, device=device
|
129
134
|
)
|
130
135
|
|
131
136
|
def get_buf_infos(self):
|
132
137
|
ptrs = [
|
133
138
|
self.output_ids.data_ptr(),
|
139
|
+
self.cached_tokens.data_ptr(),
|
134
140
|
self.output_token_logprobs_val.data_ptr(),
|
135
141
|
self.output_token_logprobs_idx.data_ptr(),
|
136
142
|
self.output_top_logprobs_val.data_ptr(),
|
137
143
|
self.output_top_logprobs_idx.data_ptr(),
|
144
|
+
self.output_topk_p.data_ptr(),
|
145
|
+
self.output_topk_index.data_ptr(),
|
138
146
|
self.output_hidden_states.data_ptr(),
|
139
147
|
]
|
140
148
|
data_lens = [
|
141
149
|
self.output_ids.nbytes,
|
150
|
+
self.cached_tokens.nbytes,
|
142
151
|
self.output_token_logprobs_val.nbytes,
|
143
152
|
self.output_token_logprobs_idx.nbytes,
|
144
153
|
self.output_top_logprobs_val.nbytes,
|
145
154
|
self.output_top_logprobs_idx.nbytes,
|
155
|
+
self.output_topk_p.nbytes,
|
156
|
+
self.output_topk_index.nbytes,
|
146
157
|
self.output_hidden_states.nbytes,
|
147
158
|
]
|
148
159
|
item_lens = [
|
149
160
|
self.output_ids[0].nbytes,
|
161
|
+
self.cached_tokens[0].nbytes,
|
150
162
|
self.output_token_logprobs_val[0].nbytes,
|
151
163
|
self.output_token_logprobs_idx[0].nbytes,
|
152
164
|
self.output_top_logprobs_val[0].nbytes,
|
153
165
|
self.output_top_logprobs_idx[0].nbytes,
|
166
|
+
self.output_topk_p[0].nbytes,
|
167
|
+
self.output_topk_index[0].nbytes,
|
154
168
|
self.output_hidden_states[0].nbytes,
|
155
169
|
]
|
156
170
|
return ptrs, data_lens, item_lens
|
@@ -158,16 +172,20 @@ class MetadataBuffers:
|
|
158
172
|
def get_buf(self, idx: int):
|
159
173
|
return (
|
160
174
|
self.output_ids[idx],
|
175
|
+
self.cached_tokens[idx],
|
161
176
|
self.output_token_logprobs_val[idx],
|
162
177
|
self.output_token_logprobs_idx[idx],
|
163
178
|
self.output_top_logprobs_val[idx],
|
164
179
|
self.output_top_logprobs_idx[idx],
|
180
|
+
self.output_topk_p[idx],
|
181
|
+
self.output_topk_index[idx],
|
165
182
|
self.output_hidden_states[idx],
|
166
183
|
)
|
167
184
|
|
168
185
|
def set_buf(self, req: Req):
|
169
186
|
|
170
187
|
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
188
|
+
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
|
171
189
|
if req.return_logprob:
|
172
190
|
if req.output_token_logprobs_val: # not none or empty list
|
173
191
|
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
@@ -190,8 +208,17 @@ class MetadataBuffers:
|
|
190
208
|
] = torch.tensor(
|
191
209
|
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
192
210
|
)
|
193
|
-
#
|
211
|
+
# For PD + spec decode
|
194
212
|
if req.hidden_states_tensor is not None:
|
213
|
+
# speculative_eagle_topk should not be greater than 16 currently
|
214
|
+
topk = req.output_topk_p.size(0)
|
215
|
+
|
216
|
+
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
|
217
|
+
req.output_topk_p
|
218
|
+
)
|
219
|
+
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
|
220
|
+
req.output_topk_index
|
221
|
+
)
|
195
222
|
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
196
223
|
req.hidden_states_tensor
|
197
224
|
)
|
@@ -217,7 +244,9 @@ class KVClassType(Enum):
|
|
217
244
|
BOOTSTRAP_SERVER = "bootstrap_server"
|
218
245
|
|
219
246
|
|
220
|
-
def get_kv_class(
|
247
|
+
def get_kv_class(
|
248
|
+
transfer_backend: TransferBackend, class_type: KVClassType
|
249
|
+
) -> Optional[Type]:
|
221
250
|
from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
|
222
251
|
|
223
252
|
if transfer_backend == TransferBackend.MOONCAKE:
|
@@ -305,49 +334,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
|
|
305
334
|
return (num_kv_indices + page_size - 1) // page_size
|
306
335
|
|
307
336
|
|
308
|
-
#########################
|
309
|
-
# PDLB Registry
|
310
|
-
#########################
|
311
|
-
|
312
|
-
|
313
|
-
@dataclasses.dataclass
|
314
|
-
class PDRegistryRequest:
|
315
|
-
"""A request to register a machine itself to the LB."""
|
316
|
-
|
317
|
-
mode: str
|
318
|
-
registry_url: str
|
319
|
-
bootstrap_port: Optional[int] = None
|
320
|
-
|
321
|
-
def __post_init__(self):
|
322
|
-
if self.mode == "prefill" and self.bootstrap_port is None:
|
323
|
-
raise ValueError("Bootstrap port must be set in PREFILL mode.")
|
324
|
-
elif self.mode == "decode" and self.bootstrap_port is not None:
|
325
|
-
raise ValueError("Bootstrap port must not be set in DECODE mode.")
|
326
|
-
elif self.mode not in ["prefill", "decode"]:
|
327
|
-
raise ValueError(
|
328
|
-
f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
|
329
|
-
)
|
330
|
-
|
331
|
-
|
332
|
-
def register_disaggregation_server(
|
333
|
-
mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
|
334
|
-
):
|
335
|
-
boostrap_port = bootstrap_port if mode == "prefill" else None
|
336
|
-
registry_request = PDRegistryRequest(
|
337
|
-
mode=mode,
|
338
|
-
registry_url=f"http://{get_ip()}:{server_port}",
|
339
|
-
bootstrap_port=boostrap_port,
|
340
|
-
)
|
341
|
-
res = requests.post(
|
342
|
-
f"{pdlb_url}/register",
|
343
|
-
json=dataclasses.asdict(registry_request),
|
344
|
-
)
|
345
|
-
if res.status_code != 200:
|
346
|
-
warnings.warn(
|
347
|
-
f"Failed to register disaggregation server: {res.status_code} {res.text}"
|
348
|
-
)
|
349
|
-
|
350
|
-
|
351
337
|
#########################
|
352
338
|
# Misc
|
353
339
|
#########################
|
@@ -0,0 +1,16 @@
|
|
1
|
+
MiB = 1024 * 1024

# Largest payload, in bytes, that the symmetric-memory all-reduce will accept.
# Looked up as SYMM_MEM_ALL_REDUCE_MAX_SIZES[cc_major][world_size], where
# cc_major is the CUDA compute-capability major version of the device.
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
    9: {2: 64 * MiB, 4: 32 * MiB, 6: 64 * MiB, 8: 64 * MiB},
    10: {2: 64 * MiB, 4: 32 * MiB, 6: 128 * MiB, 8: 128 * MiB},
}
|
@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
|
18
18
|
|
19
19
|
from sglang.srt.utils import (
|
20
20
|
format_tcp_address,
|
21
|
-
|
21
|
+
get_local_ip_auto,
|
22
22
|
get_open_port,
|
23
23
|
is_valid_ipv6_address,
|
24
24
|
)
|
@@ -191,7 +191,9 @@ class MessageQueue:
|
|
191
191
|
self.n_remote_reader = n_remote_reader
|
192
192
|
|
193
193
|
if connect_ip is None:
|
194
|
-
connect_ip =
|
194
|
+
connect_ip = (
|
195
|
+
get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
|
196
|
+
)
|
195
197
|
|
196
198
|
context = Context()
|
197
199
|
|
@@ -0,0 +1,164 @@
|
|
1
|
+
# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
|
2
|
+
import logging
|
3
|
+
from typing import Optional, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
import torch.distributed as dist
|
7
|
+
from torch.distributed import ProcessGroup
|
8
|
+
|
9
|
+
from sglang.srt.distributed.device_communicators.all_reduce_utils import (
|
10
|
+
SYMM_MEM_ALL_REDUCE_MAX_SIZES,
|
11
|
+
)
|
12
|
+
from sglang.srt.utils import get_device_capability, is_cuda, is_hip
|
13
|
+
|
14
|
+
try:
|
15
|
+
import torch.distributed._symmetric_memory as torch_symm_mem
|
16
|
+
|
17
|
+
symm_mem_available = True
|
18
|
+
except ImportError:
|
19
|
+
symm_mem_available = False
|
20
|
+
|
21
|
+
|
22
|
+
logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_hip = is_hip()

# Platform gate: True only on CUDA; HIP (ROCm) and every other backend get False.
symm_mem_is_available = True if _is_cuda else False
|
32
|
+
|
33
|
+
|
34
|
+
class SymmMemCommunicator:
    """
    Thin wrapper around symmetric-memory collectives.

    This communicator:
    - Validates device capability and world size.
    - Allocates a shared symmetric buffer.
    - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
    - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.

    If any prerequisite is not met, the instance remains disabled and will
    decline to perform symmetric-memory all-reduce.
    """

    # Mapping: compute capability major -> supported world sizes for multimem
    # If the current (cc_major, world_size) is not listed, we fall back
    # to the two-shot path.
    _WORLD_SIZES_MULTIMEM = {
        9: [4, 6, 8],
        10: [6, 8],
    }

    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
        """
        Args:
            group: Torch process group used for rendezvous and naming.
            device: Target CUDA device (index, 'cuda:X', or torch.device).

        Unmet prerequisites never raise. They leave ``self.disabled`` True:
        no symmetric-memory module, a compute capability below 9 or without
        an entry in SYMM_MEM_ALL_REDUCE_MAX_SIZES, an unsupported world size,
        or no multicast support.
        """
        self.disabled = True

        if not symm_mem_available:
            return

        if isinstance(device, int):
            device = torch.device(f"cuda:{device}")
        elif isinstance(device, str):
            device = torch.device(device)
        torch.cuda.set_device(device)
        self.dtype = torch.bfloat16
        self.device = device
        self.group = group
        self.world_size = dist.get_world_size(self.group)
        self.device_capability = torch.cuda.get_device_capability(device)[0]

        # A capability >= 9 that has no tuned size table would otherwise raise
        # KeyError below; treat it as unsupported, just like capability < 9.
        max_sizes = SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(self.device_capability)
        if self.device_capability < 9 or max_sizes is None:
            logger.warning(
                "SymmMemCommunicator: Device capability %s not supported, "
                "communicator is not available.",
                self.device_capability,
            )
            return
        if self.world_size not in max_sizes:
            logger.warning(
                "SymmMemCommunicator: World size %d not supported, "
                "communicator is not available.",
                self.world_size,
            )
            return
        self.max_size = max_sizes[self.world_size]
        # Buffer length is in elements: max_size bytes / bytes per bf16 element.
        self.buffer = torch_symm_mem.empty(
            self.max_size // self.dtype.itemsize,
            device=self.device,
            dtype=self.dtype,
        )
        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
        if handle.multicast_ptr == 0:
            logger.warning(
                "SymmMemCommunicator: symmetric memory "
                "multicast operations are not supported."
            )
            self.buffer = None
            self.disabled = True
            return
        self.disabled = False

    def should_symm_mem_allreduce(self, inp: torch.Tensor) -> bool:
        """
        Fast-path eligibility check for a given tensor.

        Conditions:
        - Communicator must be enabled.
        - dtype must be bfloat16 (matches kernel + buffer dtype).
        - Total byte size must be 4-byte aligned (hardware requirement).
        - Payload must be smaller than the symmetric-memory max size.

        Returns:
            True if the symmetric-memory path can handle this tensor.
        """
        if self.disabled:
            return False
        if inp.dtype != self.dtype:
            return False
        inp_size = inp.numel() * inp.element_size()
        # enforce 4-byte alignment
        if inp_size % 4 != 0:
            return False
        return inp_size < self.max_size

    def all_reduce(
        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
    ) -> Optional[torch.Tensor]:
        """
        Perform a sum all-reduce via symmetric memory.

        The reduction runs in place on the symmetric buffer. ``inp`` itself is
        not modified.

        Args:
            inp: Input tensor on the target CUDA device (bfloat16). Callers are
                expected to have checked should_symm_mem_allreduce(inp) first.
            out: Optional output tensor; if omitted, a new tensor is allocated.

        Returns:
            The reduced tensor (same shape as inp), or None if disabled.

        Implementation details:
        - Stages 'inp' into the symmetric buffer.
        - Selects 'multimem' or 'two_shot' kernel based on topology.
        - Writes the result into 'out' and returns it.
        """
        # Honor the documented contract instead of touching a missing buffer.
        if self.disabled:
            return None
        if out is None:
            out = torch.empty_like(inp)
        staged = self.buffer[: inp.numel()]
        # reshape (rather than view) also accepts non-contiguous inputs; the
        # data is copied into the symmetric buffer in either case.
        staged.copy_(inp.reshape(-1))
        # Unlisted (capability, world_size) combinations fall back to two-shot.
        multimem_sizes = self._WORLD_SIZES_MULTIMEM.get(self.device_capability, ())
        if self.world_size in multimem_sizes:
            torch.ops.symm_mem.multimem_all_reduce_(
                staged, "sum", self.group.group_name
            )
        else:
            torch.ops.symm_mem.two_shot_all_reduce_(
                staged, "sum", self.group.group_name
            )
        out.copy_(staged.view(out.shape))
        return out
|