sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +54 -37
- sglang/bench_one_batch_server.py +340 -34
- sglang/bench_serving.py +340 -159
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +9 -2
- sglang/profiler.py +20 -3
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +309 -0
- sglang/srt/configs/load_config.py +33 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +284 -118
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +576 -0
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +6 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +26 -15
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +268 -98
- sglang/srt/disaggregation/decode.py +172 -39
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +203 -555
- sglang/srt/disaggregation/nixl/conn.py +217 -63
- sglang/srt/disaggregation/prefill.py +113 -270
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +203 -97
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +85 -65
- sglang/srt/entrypoints/grpc_server.py +632 -305
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +169 -17
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +327 -34
- sglang/srt/entrypoints/openai/serving_base.py +74 -8
- sglang/srt/entrypoints/openai/serving_chat.py +202 -118
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +20 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +47 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +323 -0
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +21 -16
- sglang/srt/function_call/glm4_moe_detector.py +4 -8
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +61 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +98 -7
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/grpc_request_manager.py +915 -0
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
- sglang/srt/layers/activation.py +11 -7
- sglang/srt/layers/attention/aiter_backend.py +17 -18
- sglang/srt/layers/attention/ascend_backend.py +125 -10
- sglang/srt/layers/attention/attention_registry.py +226 -0
- sglang/srt/layers/attention/base_attn_backend.py +32 -4
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +52 -15
- sglang/srt/layers/attention/flashinfer_backend.py +357 -212
- sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
- sglang/srt/layers/attention/flashmla_backend.py +9 -7
- sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
- sglang/srt/layers/attention/mamba/mamba.py +514 -1
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +23 -0
- sglang/srt/layers/attention/nsa_backend.py +1201 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +249 -42
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
- sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +61 -3
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +19 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +28 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +47 -15
- sglang/srt/layers/linear.py +30 -5
- sglang/srt/layers/logits_processor.py +161 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
- sglang/srt/layers/moe/ep_moe/layer.py +243 -448
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +27 -1
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +86 -20
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +43 -15
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +141 -81
- sglang/srt/layers/quantization/mxfp4.py +17 -34
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -24
- sglang/srt/layers/quantization/w8a8_int8.py +45 -27
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +750 -46
- sglang/srt/layers/sampler.py +84 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +23 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +9 -4
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +33 -7
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +41 -17
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +83 -152
- sglang/srt/managers/data_parallel_controller.py +156 -87
- sglang/srt/managers/detokenizer_manager.py +51 -24
- sglang/srt/managers/io_struct.py +223 -129
- sglang/srt/managers/mm_utils.py +49 -10
- sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +130 -0
- sglang/srt/managers/schedule_batch.py +340 -529
- sglang/srt/managers/schedule_policy.py +158 -18
- sglang/srt/managers/scheduler.py +665 -620
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
- sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
- sglang/srt/managers/tokenizer_manager.py +462 -226
- sglang/srt/managers/tp_worker.py +217 -156
- sglang/srt/managers/utils.py +79 -47
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +42 -28
- sglang/srt/mem_cache/base_prefix_cache.py +3 -3
- sglang/srt/mem_cache/chunk_cache.py +20 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +38 -0
- sglang/srt/mem_cache/hicache_storage.py +44 -2
- sglang/srt/mem_cache/hiradix_cache.py +134 -34
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +602 -208
- sglang/srt/mem_cache/memory_pool_host.py +134 -183
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +263 -78
- sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +115 -58
- sglang/srt/metrics/collector.py +113 -120
- sglang/srt/metrics/func_timer.py +3 -8
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +81 -36
- sglang/srt/model_executor/forward_batch_info.py +40 -50
- sglang/srt/model_executor/model_runner.py +507 -319
- sglang/srt/model_executor/npu_graph_runner.py +11 -5
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +438 -37
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +200 -27
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +40 -56
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +25 -4
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +793 -235
- sglang/srt/models/dots_ocr.py +171 -0
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +570 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -3
- sglang/srt/models/glm4_moe.py +17 -40
- sglang/srt/models/glm4_moe_nextn.py +4 -4
- sglang/srt/models/glm4v.py +3 -2
- sglang/srt/models/glm4v_moe.py +6 -6
- sglang/srt/models/gpt_oss.py +12 -35
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +4 -2
- sglang/srt/models/llama.py +6 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +6 -23
- sglang/srt/models/longcat_flash_nextn.py +4 -15
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +27 -6
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +5 -5
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +70 -4
- sglang/srt/models/qwen2_vl.py +6 -3
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +50 -38
- sglang/srt/models/qwen3_next.py +43 -21
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +791 -0
- sglang/srt/models/qwen3_vl_moe.py +343 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +268 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +61 -0
- sglang/srt/multimodal/processors/base_processor.py +21 -9
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +2 -4
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +20 -10
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +83 -17
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +36 -23
- sglang/srt/sampling/sampling_params.py +75 -0
- sglang/srt/server_args.py +1300 -338
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +161 -0
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
- sglang/srt/speculative/eagle_info.py +786 -0
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +113 -1270
- sglang/srt/speculative/eagle_worker.py +120 -285
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/ngram_info.py +433 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +49 -0
- sglang/srt/speculative/spec_utils.py +641 -0
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +35 -18
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/{utils.py → utils/common.py} +583 -113
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +120 -11
- sglang/test/runners.py +3 -1
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +8 -2
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +3 -4
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +430 -0
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +93 -1
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +432 -16
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
- sglang/srt/entrypoints/grpc_request_manager.py +0 -580
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/custom_all_reduce.py

@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     is_weak_contiguous,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip, log_info_on_rank0
 
 logger = logging.getLogger(__name__)
 
@@ -32,7 +32,7 @@ try:
         ops.meta_size()
     else:
         # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
-        import sgl_kernel
+        import sgl_kernel  # noqa: F401
     custom_ar = True
 except Exception:
     # For CPUs
@@ -185,7 +185,7 @@ class CustomAllreduce:
         # is enough for 131072 such tuples. The largest model I've seen only
         # needs less than 10000 of registered tuples.
         self.rank_data = torch.empty(
-            …
+            max_size, dtype=torch.uint8, device=self.device
         )
         self._ptr = ops.init_custom_ar(
             self.meta_ptrs, self.rank_data, rank, self.full_nvlink
@@ -202,7 +202,7 @@ class CustomAllreduce:
         )
         handles, offsets = self._gather_ipc_meta(shard_data)
         self.rank_data = torch.empty(
-            …
+            max_size, dtype=torch.uint8, device=self.device
         )
         self._ptr = ops.init_custom_ar(
             self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
@@ -301,11 +301,11 @@ class CustomAllreduce:
         if _is_hip:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
             handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
-            logger…
+            log_info_on_rank0(logger, f"Registering {len(offset)} cuda graph addresses")
             ops.register_graph_buffers(self._ptr, handles, offsets)
         else:
             handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-            logger…
+            log_info_on_rank0(logger, f"Registering {len(offset)} cuda graph addresses")
             # We cannot directly use `dist.all_gather_object` here
             # because it is incompatible with `gloo` backend under inference mode.
             # see https://github.com/pytorch/pytorch/issues/126032 for details.
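
Note: the two log statements above now go through log_info_on_rank0 so that multi-rank runs print the registration message once instead of once per process. A minimal sketch of such a helper, assuming torch.distributed supplies the rank (the real helper lives in sglang.srt.utils and may differ):

    import logging

    import torch.distributed as dist


    def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
        # Log only on rank 0; every rank would otherwise emit a duplicate line.
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.info(msg)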
sglang/srt/distributed/device_communicators/pymscclpp.py

@@ -4,7 +4,7 @@ import math
 import os
 from contextlib import contextmanager
 from enum import IntEnum
-from typing import …
+from typing import Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -24,7 +24,7 @@ if _is_hip:
     mscclpp_is_available = False
 if _is_cuda:
     try:
-        import sgl_kernel
+        import sgl_kernel  # noqa: F401
 
         mscclpp_is_available = True
     except:
sglang/srt/distributed/device_communicators/pynccl.py

@@ -30,6 +30,7 @@ class PyNcclCommunicator:
         group: Union[ProcessGroup, StatelessProcessGroup],
         device: Union[int, str, torch.device],
         library_path: Optional[str] = None,
+        use_current_stream: bool = False,
     ):
         """
         Args:
@@ -74,6 +75,7 @@ class PyNcclCommunicator:
 
         self.available = True
         self.disabled = False
+        self.use_current_stream = use_current_stream
 
         self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
@@ -123,6 +125,21 @@ class PyNcclCommunicator:
         # when we are using CUDA graph.
         self.disabled = True
 
+    def _resolve_stream(self, stream: Optional[torch.cuda.Stream]):
+        """Return the stream to use for NCCL calls.
+
+        Behavior mirrors the previous inline logic:
+        - if an explicit stream is provided, return it
+        - if stream is None and self.use_current_stream is True, return
+          torch.cuda.current_stream()
+        - otherwise return the communicator's default stream (self.stream)
+        """
+        if stream is not None:
+            return stream
+        if self.use_current_stream:
+            return torch.cuda.current_stream()
+        return self.stream
+
     def all_reduce(
         self, tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None
     ):
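
_resolve_stream gives the six collective wrappers below a single place to pick a CUDA stream. A standalone sketch of the precedence it encodes (illustrative names, not the class itself):

    import torch


    def resolve_stream(explicit, use_current_stream: bool, default):
        # 1) an explicitly passed stream always wins
        if explicit is not None:
            return explicit
        # 2) otherwise optionally follow the caller's current CUDA stream,
        #    which matters when collectives run on side streams
        if use_current_stream:
            return torch.cuda.current_stream()
        # 3) fall back to the communicator's own dedicated stream
        return default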
@@ -135,8 +152,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclAllReduce(
             buffer_type(tensor.data_ptr()),
             buffer_type(tensor.data_ptr()),
@@ -163,8 +179,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
 
         if sizes is not None:
             split_offset = 0
@@ -210,8 +225,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {input_tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
 
         if sizes is not None:
             split_offset = 0
@@ -249,8 +263,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclSend(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -267,8 +280,7 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
        )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
         self.nccl.ncclRecv(
             buffer_type(tensor.data_ptr()),
             tensor.numel(),
@@ -285,8 +297,8 @@ class PyNcclCommunicator:
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}"
         )
-        if stream is None:
-            stream = self.stream
+        stream = self._resolve_stream(stream)
+
         if src == self.rank:
             sendbuff = buffer_type(tensor.data_ptr())
             # NCCL requires the sender also to have a receive buffer
sglang/srt/distributed/device_communicators/pynccl_allocator.py

@@ -5,7 +5,7 @@ from packaging import version
 from torch.cuda.memory import CUDAPluggableAllocator
 
 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.…
+from sglang.srt.server_args import get_global_server_args
 
 nccl_allocator_source = """
 #include <nccl.h>
@@ -32,7 +32,7 @@ _graph_pool_id = None
 
 
 def is_symmetric_memory_enabled():
-    return …
+    return get_global_server_args().enable_symm_mem
 
 
 def set_graph_pool_id(graph_pool_id):
sglang/srt/distributed/device_communicators/shm_broadcast.py

@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 from sglang.srt.utils import (
     format_tcp_address,
-    …
+    get_local_ip_auto,
     get_open_port,
     is_valid_ipv6_address,
 )
@@ -191,7 +191,9 @@ class MessageQueue:
         self.n_remote_reader = n_remote_reader
 
         if connect_ip is None:
-            connect_ip = …
+            connect_ip = (
+                get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
+            )
 
         context = Context()
 
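
The fallback above advertises a routable address only when readers on other nodes must connect; purely local readers keep loopback. The same selection as a tiny sketch (pick_connect_ip and detected_ip are illustrative names):

    def pick_connect_ip(n_remote_reader: int, detected_ip: str) -> str:
        # Remote readers need an address reachable from other hosts;
        # local-only setups can stay on loopback.
        return detected_ip if n_remote_reader > 0 else "127.0.0.1"


    assert pick_connect_ip(0, "10.0.0.5") == "127.0.0.1"
    assert pick_connect_ip(2, "10.0.0.5") == "10.0.0.5"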
sglang/srt/distributed/device_communicators/symm_mem.py (new file)

@@ -0,0 +1,164 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
+import logging
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+)
+from sglang.srt.utils import is_cuda, is_hip
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+symm_mem_is_available = False
+if _is_hip:
+    symm_mem_is_available = False
+if _is_cuda:
+    symm_mem_is_available = True
+
+
+class SymmMemCommunicator:
+    """
+    Thin wrapper around symmetric-memory collectives.
+
+    This communicator:
+      - Validates device capability and world size.
+      - Allocates a shared symmetric buffer.
+      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
+      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.
+
+    If any prerequisite is not met, the instance remains disabled and will
+    decline to perform symmetric-memory all-reduce.
+    """
+
+    # Mapping: compute capability major -> supported world sizes for multimem
+    # If the current (cc_major, world_size) is not listed, we fall back
+    # to the two-shot path.
+    _WORLD_SIZES_MULTIMEM = {
+        9: [4, 6, 8],
+        10: [6, 8],
+    }
+
+    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
+        """
+        Args:
+            group: Torch process group used for rendezvous and naming.
+            device: Target CUDA device (index, 'cuda:X', or torch.device).
+        """
+
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        self.device_capability = torch.cuda.get_device_capability(device)[0]
+        if self.device_capability < 9:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size
+        ]
+        self.buffer = torch_symm_mem.empty(
+            self.max_size // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning(
+                "SymmMemCommunicator: symmetric memory "
+                "multicast operations are not supported."
+            )
+            self.buffer = None
+            self.disabled = True
+            return
+        self.disabled = False
+
+    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+        """
+        Fast-path eligibility check for a given tensor.
+
+        Conditions:
+          - Communicator must be enabled.
+          - dtype must be bfloat16 (matches kernel + buffer dtype).
+          - Total byte size must be 4-byte aligned (hardware requirement).
+          - Payload must be smaller than the symmetric-memory max size.
+
+        Returns:
+            True if the symmetric-memory path can handle this tensor.
+        """
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # enforce 4-byte alignment
+        if inp_size % 4 != 0:
+            return False
+        return inp_size < self.max_size
+
+    def all_reduce(
+        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Perform an in-place sum all-reduce via symmetric memory.
+
+        Args:
+            inp: Input tensor on the target CUDA device (bfloat16).
+            out: Optional output tensor; if omitted, a new tensor is allocated.
+
+        Returns:
+            The reduced tensor (same shape as inp), or None if disabled.
+
+        Implementation details:
+          - Stages 'inp' into the symmetric buffer.
+          - Selects 'multimem' or 'two_shot' kernel based on topology.
+          - Writes the result into 'out' and returns it.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        else:
+            torch.ops.symm_mem.two_shot_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
+        return out
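
A hedged usage sketch for SymmMemCommunicator: check eligibility first and fall back to the regular process-group all-reduce otherwise. maybe_symm_mem_allreduce is an illustrative wrapper, not part of the new module:

    import torch
    import torch.distributed as dist


    def maybe_symm_mem_allreduce(comm, tensor: torch.Tensor) -> torch.Tensor:
        # Fast path: enabled communicator, bfloat16 payload, 4-byte aligned,
        # and small enough for the symmetric buffer.
        if comm is not None and comm.should_symm_mem_allreduce(tensor):
            return comm.all_reduce(tensor)
        # Fallback: ordinary NCCL all-reduce through the default group.
        dist.all_reduce(tensor)
        return tensor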
sglang/srt/distributed/naive_distributed.py

@@ -1,10 +1,9 @@
-import base64
-import os
 import pickle
 import time
 from pathlib import Path
 from typing import Any, List, Optional
 
+import pybase64
 import torch
 
 from sglang.srt.utils import MultiprocessingSerializer
@@ -78,14 +77,16 @@ class NaiveDistributed:
         )
 
         _get_path(self._rank).write_text(
-            base64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
+            pybase64.b64encode(pickle.dumps(obj)).decode("utf-8") + text_postfix
         )
 
         def _read_one(interesting_rank: int):
             p = _get_path(interesting_rank)
             while True:
                 if p.exists() and (text := p.read_text()).endswith(text_postfix):
-                    return pickle.loads(…)
+                    return pickle.loads(
+                        pybase64.b64decode(text[: -len(text_postfix)], validate=True)
+                    )
                 time.sleep(0.001)
 
         return [