sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/parallel_state.py

@@ -4,7 +4,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""
+"""Distributed state.
 It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:
 
@@ -53,16 +53,26 @@ from sglang.srt.utils import (
 
 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_supports_custom_op = supports_custom_op()
 
 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
 
 
+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
 
 
-
+@dataclass
+class P2PWork:
+    work: Optional[torch.distributed.Work]
+    payload: Optional[torch.Tensor]
 
 
 def _split_tensor_dict(
@@ -114,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)
 
 
-if supports_custom_op():
+if _supports_custom_op:
 
     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
@@ -205,12 +215,14 @@ class GroupCoordinator:
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
     use_message_queue_broadcaster: (
        bool  # a hint of whether to use message queue broadcaster
     )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
+    symm_mem_comm: Optional[Any]  # Symm mem communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
 
     def __init__(
@@ -221,6 +233,7 @@ class GroupCoordinator:
         use_pynccl: bool,
         use_pymscclpp: bool,
         use_custom_allreduce: bool,
+        use_torch_symm_mem: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
         use_npu_communicator: bool,
@@ -269,12 +282,13 @@ class GroupCoordinator:
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
         self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
-        # lazy import to avoid documentation build error
+        # Lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
@@ -284,6 +298,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator,
+        )
 
         if is_hip():
             from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -332,6 +349,13 @@ class GroupCoordinator:
                 except Exception as e:
                     logger.warning(f"Failed to initialize QuickAllReduce: {e}")
 
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+        if self.use_torch_symm_mem and self.world_size > 1:
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -436,6 +460,7 @@ class GroupCoordinator:
         # custom allreduce | enabled | enabled |
         # PyNccl | disabled| enabled |
         # PyMscclpp | disabled| enabled |
+        # TorchSymmMem | disabled| enabled |
         # torch.distributed | enabled | disabled|
         #
         # Note: When custom quick allreduce is enabled, a runtime check
@@ -489,14 +514,12 @@ class GroupCoordinator:
 
         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
 
-        if not supports_custom_op():
+        if not _supports_custom_op:
             self._all_reduce_in_place(input_)
             return input_
 
@@ -522,23 +545,29 @@ class GroupCoordinator:
 
         outplace_all_reduce_method = None
         if (
-            self.qr_comm is not None
-            and not self.qr_comm.disabled
-            and self.qr_comm.should_quick_allreduce(input_)
-        ):
-            outplace_all_reduce_method = "qr"
-        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
         ):
             outplace_all_reduce_method = "ca"
+        elif (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
         elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             outplace_all_reduce_method = "pymscclpp"
+        elif (
+            self.symm_mem_comm is not None
+            and not self.symm_mem_comm.disabled
+            and self.symm_mem_comm.should_symm_mem_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "symm_mem"
         if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
                 input_,
@@ -552,16 +581,20 @@ class GroupCoordinator:
     def _all_reduce_out_place(
         self, input_: torch.Tensor, outplace_all_reduce_method: str
     ) -> torch.Tensor:
-        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
+        qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
+        symm_mem_comm = self.symm_mem_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
-        if outplace_all_reduce_method == "qr":
-            assert not qr_comm.disabled
-            out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "ca":
+        if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
+        elif outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "symm_mem":
+            assert not symm_mem_comm.disabled
+            out = symm_mem_comm.all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
@@ -636,7 +669,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not supports_custom_op():
+        if _is_npu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -696,15 +729,13 @@ class GroupCoordinator:
         )
 
         # All-gather.
-        if input_.is_cpu and is_shm_available(
-            input_.dtype, self.world_size, self.local_size
-        ):
-            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-
         if input_.is_cpu:
-            torch.distributed.all_gather_into_tensor(
-                output_tensor, input_, group=self.device_group
-            )
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
         else:
             self.all_gather_into_tensor(output_tensor, input_)
 
@@ -860,76 +891,89 @@ class GroupCoordinator:
         torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
         return objs
 
-    def send_object(self, obj: Any, dst: int) -> None:
-        """Send the input object list to the destination rank."""
-        """NOTE: `dst` is the local rank of the destination rank."""
+    def send_object(
+        self,
+        obj: Any,
+        dst: int,
+        async_send: bool = False,
+    ) -> List[P2PWork]:
+        """
+        Send the input object list to the destination rank.
+        This function uses the CPU group for all communications.
 
-        assert dst < self.world_size, f"Invalid dst rank ({dst})"
+        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+        use other functions (e.g., send), or implement a new function (e.g., send_object_device).
 
+        NOTE: `dst` is the local rank of the destination rank.
+        """
+
+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
         assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank."
         )
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
 
         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
-            device=torch.cuda.current_device()
-        )
-
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
-            dtype=torch.long,
-            device=torch.cuda.current_device(),
+            [object_tensor.numel()], dtype=torch.long, device="cpu"
         )
 
         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.device_group
+        p2p_work = []
+        size_work = send_func(
+            size_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
         )
+        if async_send:
+            p2p_work.append(P2PWork(size_work, size_tensor))
 
-        # Send object
-        torch.distributed.send(
-            object_tensor, dst=self.ranks[dst], group=self.device_group
+        object_work = send_func(
+            object_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
         )
+        if async_send:
+            p2p_work.append(P2PWork(object_work, object_tensor))
 
-        return None
+        return p2p_work
 
-    def recv_object(self, src: int) -> Any:
+    def recv_object(
+        self,
+        src: int,
+    ) -> Any:
         """Receive the input object list from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""
 
         assert src < self.world_size, f"Invalid src rank ({src})"
-
         assert (
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
 
         # Receive object size
-        rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.device_group
+        # We have to use irecv here to make it work for both isend and send.
+        work = torch.distributed.irecv(
+            size_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()
 
         # Tensor to receive serialized objects into.
-        object_tensor = torch.empty(  # type: ignore[call-overload]
+        object_tensor: Any = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
 
-        rank_object = torch.distributed.recv(
-            object_tensor, src=self.ranks[src], group=self.device_group
+        work = torch.distributed.irecv(
+            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
+        work.wait()
 
-        assert (
-            rank_object == rank_size
-        ), "Received object sender rank does not match the size sender rank."
-
-        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
-
+        obj = pickle.loads(object_tensor.numpy())
         return obj
 
     def broadcast_tensor_dict(
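The hunk above reworks `GroupCoordinator.send_object`/`recv_object` to serialize onto CPU tensors, communicate over `self.cpu_group`, and optionally return `P2PWork` handles instead of blocking. A minimal usage sketch, not taken from the diff, assuming `torch.distributed` is already initialized and `get_world_group()` returns a two-rank coordinator:

```python
# Hypothetical sketch of the new async object send path introduced in 0.5.3.
from sglang.srt.distributed.parallel_state import get_world_group

group = get_world_group()

if group.rank_in_group == 0:
    # With async_send=True, send_object returns P2PWork items; keep them so the
    # payload tensors stay alive, then wait before discarding them.
    works = group.send_object({"step": 1}, dst=1, async_send=True)
    # ... overlap other work here ...
    for w in works:
        if w.work is not None:
            w.work.wait()
else:
    # recv_object now uses irecv + wait internally, so it pairs with either the
    # blocking (send) or the asynchronous (isend) sender.
    obj = group.recv_object(src=0)
```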
@@ -1019,12 +1063,13 @@ class GroupCoordinator:
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if not torch.distributed.is_initialized() or self.world_size == 1:
+        if self.world_size == 1:
             return tensor_dict
 
         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1049,7 +1094,10 @@ class GroupCoordinator:
         # 1. Superior D2D transfer bandwidth
         # 2. Ability to overlap send and recv operations
         # Thus the net performance gain justifies this approach.
-        self.send_object(metadata_list, dst=dst)
+
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
@@ -1059,15 +1107,11 @@ class GroupCoordinator:
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-            if tensor.is_cpu:
-                # use metadata_group for CPU tensors
-                torch.distributed.send(
-                    tensor, dst=self.ranks[dst], group=metadata_group
-                )
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = send_func(tensor, self.ranks[dst], group=comm_group)
+            if async_send:
+                p2p_works.append(P2PWork(work, tensor))
+        return p2p_works
 
     def recv_tensor_dict(
         self,
@@ -1113,17 +1157,15 @@ class GroupCoordinator:
                     orig_shape = tensor.shape
                     tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
-                if tensor.is_cpu:
-                    # use metadata_group for CPU tensors
-                    torch.distributed.recv(
-                        tensor, src=self.ranks[src], group=metadata_group
-                    )
-                else:
-                    # use group for GPU tensors
-                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+                # We have to use irecv here to make it work for both isend and send.
+                comm_group = metadata_group if tensor.is_cpu else group
+                work = torch.distributed.irecv(
+                    tensor, src=self.ranks[src], group=comm_group
+                )
+                work.wait()
+
                 if use_all_gather:
-                    # do the allgather
-                    tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
+                    tensor = all_gather_group.all_gather(tensor, dim=0)
                     tensor = tensor.reshape(orig_shape)
 
                 tensor_dict[key] = tensor
@@ -1201,6 +1243,7 @@ def init_world_group(
         use_pynccl=False,
         use_pymscclpp=False,
         use_custom_allreduce=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
@@ -1216,11 +1259,14 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
     use_mscclpp_allreduce: Optional[bool] = None,
+    use_symm_mem_allreduce: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     if use_mscclpp_allreduce is None:
         use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
+    if use_symm_mem_allreduce is None:
+        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
@@ -1228,6 +1274,7 @@ def init_model_parallel_group(
         use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
+        use_torch_symm_mem=use_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
@@ -1313,6 +1360,7 @@ logger = logging.getLogger(__name__)
 
 _ENABLE_CUSTOM_ALL_REDUCE = True
 _ENABLE_MSCCLPP_ALL_REDUCE = False
+_ENABLE_SYMM_MEM_ALL_REDUCE = False
 
 
 def set_custom_all_reduce(enable: bool):
@@ -1325,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
     _ENABLE_MSCCLPP_ALL_REDUCE = enable
 
 
+def set_symm_mem_all_reduce(enable: bool):
+    global _ENABLE_SYMM_MEM_ALL_REDUCE
+    _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+
+
 def init_distributed_environment(
     world_size: int = -1,
     rank: int = -1,
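The flag added above is off by default, so the symmetric-memory allreduce is opt-in. A rough sketch of wiring it up, assuming `torch.distributed` has already been set up through `init_distributed_environment` (illustrative only; the actual call site in sglang's server startup is not part of this diff):

```python
# Illustrative sketch: opt in to the torch symmetric-memory allreduce hint
# before the model-parallel groups are created, mirroring set_custom_all_reduce.
from sglang.srt.distributed.parallel_state import (
    initialize_model_parallel,
    set_symm_mem_all_reduce,
)

set_symm_mem_all_reduce(True)  # flips the module-level _ENABLE_SYMM_MEM_ALL_REDUCE
initialize_model_parallel(tensor_model_parallel_size=2)
# Each GroupCoordinator then builds a SymmMemCommunicator (for world_size > 1)
# and dispatches to it only when should_symm_mem_allreduce(input_) accepts the
# tensor, after the custom, quick, and pymscclpp allreduce paths have declined.
```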
@@ -1461,43 +1514,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
 
     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)
 
-    _MOE_EP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_ep",
-    )
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )
 
     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)
 
-    _MOE_TP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_tp",
-    )
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
@@ -1583,6 +1642,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
         _TP = old_tp_group
 
 
+def get_world_size():
+    """Return world size for the world group."""
+    return get_world_group().world_size
+
+
+def get_world_rank():
+    """Return my rank for the world group."""
+    return get_world_group().rank_in_group
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size
@@ -1593,6 +1662,16 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group
 
 
+def get_pipeline_model_parallel_world_size():
+    """Return world size for the pipeline model parallel group."""
+    return get_pp_group().world_size
+
+
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    return get_pp_group().rank_in_group
+
+
 def get_moe_expert_parallel_world_size():
     """Return world size for the moe expert parallel group."""
     return get_moe_ep_group().world_size
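Finally, a tiny sketch of the new module-level helpers added in the last two hunks (illustrative; values depend on how the groups were initialized):

```python
from sglang.srt.distributed import parallel_state as ps

# New in 0.5.3: world-group and pipeline-parallel lookups alongside the
# existing tensor-parallel and MoE helpers.
print(ps.get_world_rank(), ps.get_world_size())
print(ps.get_pipeline_model_parallel_rank(), ps.get_pipeline_model_parallel_world_size())
```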
|