sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py

```diff
@@ -21,9 +21,10 @@ from __future__ import annotations
 
 import logging
 import threading
+import time
 from collections import deque
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type
 
 import torch
 
@@ -42,7 +43,12 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    FINISH_LENGTH,
+    Req,
+    RequestStage,
+    ScheduleBatch,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -140,8 +146,10 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
 
-        kv_manager_class = get_kv_class(
-
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.PREFILL,
             self.scheduler.server_args,
@@ -168,6 +176,7 @@ class PrefillBootstrapQueue:
             pp_rank=self.pp_rank,
         )
         self._process_req(req)
+        req.add_latency(RequestStage.PREFILL_PREPARE)
         self.queue.append(req)
 
     def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -254,8 +263,11 @@ class PrefillBootstrapQueue:
 
             num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
             req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
+
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
+            req.time_stats.wait_queue_entry_time = time.perf_counter()
+            req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
 
         self.queue = [
             entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -397,11 +409,11 @@ class SchedulerDisaggregationPrefillMixin:
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
-            req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
+                req.add_latency(RequestStage.PREFILL_FORWARD)
                 self.disagg_prefill_inflight_queue.append(req)
                 if (
                     logits_output is not None
@@ -410,9 +422,16 @@ class SchedulerDisaggregationPrefillMixin:
                     last_hidden_index = (
                         hidden_state_offset + extend_input_len_per_req[i] - 1
                     )
-                    req.
-
-                    )
+                    req.output_topk_p = batch.spec_info.topk_p[i]
+                    req.output_topk_index = batch.spec_info.topk_index[i]
+                    if self.spec_algorithm.is_eagle3():
+                        req.hidden_states_tensor = (
+                            batch.spec_info.hidden_states[i].cpu().clone()
+                        )
+                    else:
+                        req.hidden_states_tensor = (
+                            logits_output.hidden_states[last_hidden_index].cpu().clone()
+                        )
                     hidden_state_offset += extend_input_len_per_req[i]
                 else:
                     req.hidden_states_tensor = None
@@ -432,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
                 )
                 logprob_pt += num_input_logprobs
             self.send_kv_chunk(req, last_chunk=True)
+            req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
 
             if req.grammar is not None:
                 # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -529,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
             else:
                 assert False, f"Unexpected polling state {poll=}"
 
+        for req in done_reqs:
+            req.time_stats.completion_time = time.perf_counter()
+
         # Stream requests which have finished transfer
         self.stream_output(
             done_reqs,
@@ -537,6 +560,7 @@ class SchedulerDisaggregationPrefillMixin:
         )
         for req in done_reqs:
             req: Req
+            req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
             self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
             req.metadata_buffer_index = -1
 
@@ -665,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
         self.running_mbs = [
             ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
         ]
-        bids = [None] * self.pp_size
        pp_outputs: Optional[PPProxyTensors] = None
 
        # Either success or failed
@@ -737,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
                 # send the outputs to the next step
                 if self.pp_group.is_last_rank:
                     if self.cur_batch:
-                        next_token_ids
-                            result.next_token_ids,
-                            result.bid,
-                        )
+                        next_token_ids = result.next_token_ids
                         pp_outputs = PPProxyTensors(
                             {
                                 "next_token_ids": next_token_ids,
@@ -777,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
                         next_token_ids=next_pp_outputs["next_token_ids"],
                         extend_input_len_per_req=None,
                         extend_logprob_start_len_per_req=None,
-                        bid=bids[next_mb_id],
                         can_run_cuda_graph=result.can_run_cuda_graph,
                     )
                     self.process_batch_result_disagg_prefill(
@@ -794,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
 
                 # carry the outputs to the next stage
                 if not self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        bids[mb_id] = result.bid
                     if pp_outputs:
                         # send the outputs from the last round to let the next stage worker run post processing
                         self.pp_group.send_tensor_dict(
@@ -814,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
 
                     # send out proxy tensors to the next stage
                     if self.cur_batch:
+                        # FIXME(lsyin): remove this assert
+                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                         self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors,
+                            result.pp_hidden_states_proxy_tensors.tensors,
                             all_gather_group=self.attn_tp_group,
                         )
 
```
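The `RequestStage` timing calls added throughout prefill.py follow a stamp-and-accumulate pattern: `time.perf_counter()` marks stage boundaries (`wait_queue_entry_time`, `prefill_transfer_queue_entry_time`, `completion_time`), and `req.add_latency(...)` books elapsed time against the stage that just finished. The diff does not show the `Req`/`add_latency` internals, so the following is only a shape sketch of that pattern with hypothetical field names, not sglang's actual implementation:

```python
import time
from enum import Enum, auto


class RequestStage(Enum):
    # Stage names taken from the prefill.py hunks above.
    PREFILL_PREPARE = auto()
    PREFILL_BOOTSTRAP = auto()
    PREFILL_FORWARD = auto()
    PREFILL_TRANSFER_KV_CACHE = auto()


class Req:
    """Hypothetical request object; only the latency-tracking shape is sketched."""

    def __init__(self) -> None:
        self._last_stamp = time.perf_counter()
        self.stage_latencies: dict = {}

    def add_latency(self, stage: RequestStage) -> None:
        # Book the time elapsed since the previous stage boundary
        # against the stage that just completed.
        now = time.perf_counter()
        self.stage_latencies[stage] = now - self._last_stamp
        self._last_stamp = now


req = Req()
# ... bootstrap handshake would happen here ...
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
```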
sglang/srt/disaggregation/utils.py

```diff
@@ -1,21 +1,17 @@
 from __future__ import annotations
 
-import dataclasses
 import os
 import random
-import threading
-import warnings
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Type
 
 import numpy as np
-import requests
 import torch
 import torch.distributed as dist
 
-from sglang.srt.utils import
+from sglang.srt.utils import is_npu
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -89,7 +85,7 @@ class MetadataBuffers:
         self,
         size: int,
         hidden_size: int,
-
+        hidden_states_dtype: torch.dtype,
         max_top_logprobs_num: int = 128,
         custom_mem_pool: torch.cuda.MemPool = None,
     ):
@@ -111,7 +107,9 @@ class MetadataBuffers:
         # We transfer the metadata of first output token to decode
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-
+        self.cached_tokens = torch.zeros(
+            (size, 16), dtype=torch.int32, device=device
+        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
@@ -124,33 +122,49 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        # For PD + spec decode
+        self.output_topk_p = torch.zeros(
+            (size, 16), dtype=torch.float32, device=device
+        )
+        self.output_topk_index = torch.zeros(
+            (size, 16), dtype=torch.int64, device=device
+        )
         self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=
+            (size, hidden_size), dtype=hidden_states_dtype, device=device
         )
 
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.cached_tokens.data_ptr(),
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_topk_p.data_ptr(),
+            self.output_topk_index.data_ptr(),
             self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.cached_tokens.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_topk_p.nbytes,
+            self.output_topk_index.nbytes,
             self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.cached_tokens[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_topk_p[0].nbytes,
+            self.output_topk_index[0].nbytes,
             self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
```
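The `(size, 16)` shapes follow the constructor comment above: each per-request metadata slot is padded so that a single row meets the 64-byte minimum RDMA transfer size. A quick standalone check of that arithmetic (plain PyTorch, nothing sglang-specific):

```python
import torch

# One metadata slot is a row of a (size, 16) buffer:
# 16 x int32 = 64 bytes, 16 x float32 = 64 bytes, 16 x int64 = 128 bytes.
buf_i32 = torch.zeros((8, 16), dtype=torch.int32)
buf_i64 = torch.zeros((8, 16), dtype=torch.int64)
assert buf_i32[0].nbytes == 64   # meets the 64-byte RDMA minimum
assert buf_i64[0].nbytes == 128  # output_topk_index rows are int64
```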
sglang/srt/disaggregation/utils.py (continued)

```diff
@@ -158,16 +172,20 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.cached_tokens[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_topk_p[idx],
+            self.output_topk_index[idx],
             self.output_hidden_states[idx],
         )
 
     def set_buf(self, req: Req):
 
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -190,8 +208,17 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
-        #
+        # For PD + spec decode
         if req.hidden_states_tensor is not None:
+            # speculative_eagle_topk should not be greater than 16 currently
+            topk = req.output_topk_p.size(0)
+
+            self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_p
+            )
+            self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
+                req.output_topk_index
+            )
             self.output_hidden_states[req.metadata_buffer_index].copy_(
                 req.hidden_states_tensor
             )
@@ -217,7 +244,9 @@ class KVClassType(Enum):
     BOOTSTRAP_SERVER = "bootstrap_server"
 
 
-def get_kv_class(
+def get_kv_class(
+    transfer_backend: TransferBackend, class_type: KVClassType
+) -> Optional[Type]:
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
 
     if transfer_backend == TransferBackend.MOONCAKE:
```
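For context, the `PrefillBootstrapQueue` hunk earlier resolves transfer classes through this factory. A minimal sketch of that call, assuming (as the hunks suggest) that `TransferBackend` and `KVClassType` both live in this module:

```python
from typing import Optional, Type

from sglang.srt.disaggregation.utils import (
    KVClassType,
    TransferBackend,
    get_kv_class,
)

# Resolve the KV manager class for a backend, mirroring the prefill.py hunk.
# Per the new annotation, the factory may return None for an unsupported
# (backend, class_type) combination, so a guard is prudent.
kv_manager_class: Optional[Type] = get_kv_class(
    TransferBackend.MOONCAKE, KVClassType.MANAGER
)
if kv_manager_class is None:
    raise ValueError("No KV manager registered for this transfer backend")
```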
sglang/srt/disaggregation/utils.py (continued): the module also drops the PDLB registry helpers.

```diff
@@ -305,49 +334,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size
 
 
-#########################
-# PDLB Registry
-#########################
-
-
-@dataclasses.dataclass
-class PDRegistryRequest:
-    """A request to register a machine itself to the LB."""
-
-    mode: str
-    registry_url: str
-    bootstrap_port: Optional[int] = None
-
-    def __post_init__(self):
-        if self.mode == "prefill" and self.bootstrap_port is None:
-            raise ValueError("Bootstrap port must be set in PREFILL mode.")
-        elif self.mode == "decode" and self.bootstrap_port is not None:
-            raise ValueError("Bootstrap port must not be set in DECODE mode.")
-        elif self.mode not in ["prefill", "decode"]:
-            raise ValueError(
-                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
-            )
-
-
-def register_disaggregation_server(
-    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
-):
-    boostrap_port = bootstrap_port if mode == "prefill" else None
-    registry_request = PDRegistryRequest(
-        mode=mode,
-        registry_url=f"http://{get_ip()}:{server_port}",
-        bootstrap_port=boostrap_port,
-    )
-    res = requests.post(
-        f"{pdlb_url}/register",
-        json=dataclasses.asdict(registry_request),
-    )
-    if res.status_code != 200:
-        warnings.warn(
-            f"Failed to register disaggregation server: {res.status_code} {res.text}"
-        )
-
-
 #########################
 # Misc
 #########################
```
sglang/srt/distributed/device_communicators/all_reduce_utils.py

```diff
@@ -0,0 +1,16 @@
+MiB = 1024 * 1024
+
+SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
+    9: {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 64 * MiB,  # 64 MB
+        8: 64 * MiB,  # 64 MB
+    },
+    10: {
+        2: 64 * MiB,  # 64 MB
+        4: 32 * MiB,  # 32 MB
+        6: 128 * MiB,  # 128 MB
+        8: 128 * MiB,  # 128 MB
+    },
+}
```
sglang/srt/distributed/device_communicators/shm_broadcast.py

```diff
@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 from sglang.srt.utils import (
     format_tcp_address,
-
+    get_local_ip_auto,
     get_open_port,
     is_valid_ipv6_address,
 )
@@ -191,7 +191,9 @@ class MessageQueue:
         self.n_remote_reader = n_remote_reader
 
         if connect_ip is None:
-            connect_ip =
+            connect_ip = (
+                get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
+            )
 
         context = Context()
 
```
sglang/srt/distributed/device_communicators/symm_mem.py (new file)

```diff
@@ -0,0 +1,164 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
+import logging
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+)
+from sglang.srt.utils import get_device_capability, is_cuda, is_hip
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+symm_mem_is_available = False
+if _is_hip:
+    symm_mem_is_available = False
+if _is_cuda:
+    symm_mem_is_available = True
+
+
+class SymmMemCommunicator:
+    """
+    Thin wrapper around symmetric-memory collectives.
+
+    This communicator:
+      - Validates device capability and world size.
+      - Allocates a shared symmetric buffer.
+      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
+      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.
+
+    If any prerequisite is not met, the instance remains disabled and will
+    decline to perform symmetric-memory all-reduce.
+    """
+
+    # Mapping: compute capability major -> supported world sizes for multimem
+    # If the current (cc_major, world_size) is not listed, we fall back
+    # to the two-shot path.
+    _WORLD_SIZES_MULTIMEM = {
+        9: [4, 6, 8],
+        10: [6, 8],
+    }
+
+    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
+        """
+        Args:
+            group: Torch process group used for rendezvous and naming.
+            device: Target CUDA device (index, 'cuda:X', or torch.device).
+        """
+
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        self.device_capability = torch.cuda.get_device_capability(device)[0]
+        if self.device_capability < 9:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size
+        ]
+        self.buffer = torch_symm_mem.empty(
+            self.max_size // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning(
+                "SymmMemCommunicator: symmetric memory "
+                "multicast operations are not supported."
+            )
+            self.buffer = None
+            self.disabled = True
+            return
+        self.disabled = False
+
+    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+        """
+        Fast-path eligibility check for a given tensor.
+
+        Conditions:
+          - Communicator must be enabled.
+          - dtype must be bfloat16 (matches kernel + buffer dtype).
+          - Total byte size must be 4-byte aligned (hardware requirement).
+          - Payload must be smaller than the symmetric-memory max size.
+
+        Returns:
+            True if the symmetric-memory path can handle this tensor.
+        """
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # enforce 4-byte alignment
+        if inp_size % 4 != 0:
+            return False
+        return inp_size < self.max_size
+
+    def all_reduce(
+        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Perform an in-place sum all-reduce via symmetric memory.
+
+        Args:
+            inp: Input tensor on the target CUDA device (bfloat16).
+            out: Optional output tensor; if omitted, a new tensor is allocated.
+
+        Returns:
+            The reduced tensor (same shape as inp), or None if disabled.
+
+        Implementation details:
+          - Stages 'inp' into the symmetric buffer.
+          - Selects 'multimem' or 'two_shot' kernel based on topology.
+          - Writes the result into 'out' and returns it.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        else:
+            torch.ops.symm_mem.two_shot_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
+        return out
```