sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/symm_mem.py
ADDED
@@ -0,0 +1,164 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/bf214ca22625e311a2c4c0dfbf7af19128f4919c/vllm/distributed/device_communicators/symm_mem.py
+import logging
+from typing import Optional, Union
+
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from sglang.srt.distributed.device_communicators.all_reduce_utils import (
+    SYMM_MEM_ALL_REDUCE_MAX_SIZES,
+)
+from sglang.srt.utils import get_device_capability, is_cuda, is_hip
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    symm_mem_available = True
+except ImportError:
+    symm_mem_available = False
+
+
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+symm_mem_is_available = False
+if _is_hip:
+    symm_mem_is_available = False
+if _is_cuda:
+    symm_mem_is_available = True
+
+
+class SymmMemCommunicator:
+    """
+    Thin wrapper around symmetric-memory collectives.
+
+    This communicator:
+      - Validates device capability and world size.
+      - Allocates a shared symmetric buffer.
+      - Chooses between 'multimem' and 'two-shot' all-reduce kernels.
+      - Exposes a fast-path all_reduce() compatible with bfloat16 inputs.
+
+    If any prerequisite is not met, the instance remains disabled and will
+    decline to perform symmetric-memory all-reduce.
+    """
+
+    # Mapping: compute capability major -> supported world sizes for multimem
+    # If the current (cc_major, world_size) is not listed, we fall back
+    # to the two-shot path.
+    _WORLD_SIZES_MULTIMEM = {
+        9: [4, 6, 8],
+        10: [6, 8],
+    }
+
+    def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device]):
+        """
+        Args:
+            group: Torch process group used for rendezvous and naming.
+            device: Target CUDA device (index, 'cuda:X', or torch.device).
+        """
+
+        self.disabled = True
+
+        if not symm_mem_available:
+            return
+
+        if isinstance(device, int):
+            device = torch.device(f"cuda:{device}")
+        elif isinstance(device, str):
+            device = torch.device(device)
+        torch.cuda.set_device(device)
+        self.dtype = torch.bfloat16
+        self.device = device
+        self.group = group
+        self.world_size = dist.get_world_size(self.group)
+        self.device_capability = torch.cuda.get_device_capability(device)[0]
+        if self.device_capability < 9:
+            logger.warning(
+                "SymmMemCommunicator: Device capability %s not supported, "
+                "communicator is not available.",
+                self.device_capability,
+            )
+            return
+        if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability]:
+            logger.warning(
+                "SymmMemCommunicator: World size %d not supported, "
+                "communicator is not available.",
+                self.world_size,
+            )
+            return
+        self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
+            self.world_size
+        ]
+        self.buffer = torch_symm_mem.empty(
+            self.max_size // self.dtype.itemsize,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
+        if handle.multicast_ptr == 0:
+            logger.warning(
+                "SymmMemCommunicator: symmetric memory "
+                "multicast operations are not supported."
+            )
+            self.buffer = None
+            self.disabled = True
+            return
+        self.disabled = False
+
+    def should_symm_mem_allreduce(self, inp: torch.Tensor):
+        """
+        Fast-path eligibility check for a given tensor.
+
+        Conditions:
+          - Communicator must be enabled.
+          - dtype must be bfloat16 (matches kernel + buffer dtype).
+          - Total byte size must be 4-byte aligned (hardware requirement).
+          - Payload must be smaller than the symmetric-memory max size.
+
+        Returns:
+            True if the symmetric-memory path can handle this tensor.
+        """
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        # enforce 4-byte alignment
+        if inp_size % 4 != 0:
+            return False
+        return inp_size < self.max_size
+
+    def all_reduce(
+        self, inp: torch.Tensor, *, out: Optional[torch.Tensor] = None
+    ) -> Optional[torch.Tensor]:
+        """
+        Perform an in-place sum all-reduce via symmetric memory.
+
+        Args:
+            inp: Input tensor on the target CUDA device (bfloat16).
+            out: Optional output tensor; if omitted, a new tensor is allocated.
+
+        Returns:
+            The reduced tensor (same shape as inp), or None if disabled.
+
+        Implementation details:
+          - Stages 'inp' into the symmetric buffer.
+          - Selects 'multimem' or 'two_shot' kernel based on topology.
+          - Writes the result into 'out' and returns it.
+        """
+        if out is None:
+            out = torch.empty_like(inp)
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+        if self.world_size in self._WORLD_SIZES_MULTIMEM[self.device_capability]:
+            torch.ops.symm_mem.multimem_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        else:
+            torch.ops.symm_mem.two_shot_all_reduce_(
+                self.buffer[: inp.numel()], "sum", self.group.group_name
+            )
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
+        return out
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -4,7 +4,7 @@
 # Adapted from
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""
+"""Distributed state.
 It takes over the control of the distributed environment from PyTorch.
 The typical workflow is:

@@ -53,19 +53,26 @@ from sglang.srt.utils import (

 _is_npu = is_npu()
 _is_cpu = is_cpu()
+_supports_custom_op = supports_custom_op()

 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")


+TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
+
+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+
+
 @dataclass
 class GraphCaptureContext:
     stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream


-
-
-
-
+@dataclass
+class P2PWork:
+    work: Optional[torch.distributed.Work]
+    payload: Optional[torch.Tensor]


 def _split_tensor_dict(
@@ -117,7 +124,7 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)


-if
+if _supports_custom_op:

     def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
         assert group_name in _groups, f"Group {group_name} is not found."
@@ -208,12 +215,14 @@ class GroupCoordinator:
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_pymscclpp: bool  # a hint of whether to use PyMsccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_torch_symm_mem: bool  # a hint of whether to use SymmMemAllReduce
     use_message_queue_broadcaster: (
         bool  # a hint of whether to use message queue broadcaster
     )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
+    symm_mem_comm: Optional[Any]  # Symm mem communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
@@ -224,6 +233,7 @@ class GroupCoordinator:
         use_pynccl: bool,
         use_pymscclpp: bool,
         use_custom_allreduce: bool,
+        use_torch_symm_mem: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
         use_npu_communicator: bool,
@@ -272,12 +282,13 @@ class GroupCoordinator:
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
         self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster

-        #
+        # Lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
@@ -287,6 +298,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator,
+        )

         if is_hip():
             from sglang.srt.distributed.device_communicators.quick_all_reduce import (
@@ -335,6 +349,13 @@ class GroupCoordinator:
         except Exception as e:
             logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+        if self.use_torch_symm_mem and self.world_size > 1:
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -439,6 +460,7 @@ class GroupCoordinator:
         # custom allreduce  | enabled  | enabled |
         # PyNccl            | disabled | enabled |
         # PyMscclpp         | disabled | enabled |
+        # TorchSymmMem      | disabled | enabled |
         # torch.distributed | enabled  | disabled|
         #
         # Note: When custom quick allreduce is enabled, a runtime check
@@ -497,7 +519,7 @@ class GroupCoordinator:
             torch.distributed.all_reduce(input_, group=self.device_group)
             return input_

-        if not
+        if not _supports_custom_op:
             self._all_reduce_in_place(input_)
             return input_

@@ -523,23 +545,29 @@ class GroupCoordinator:

         outplace_all_reduce_method = None
         if (
-            self.qr_comm is not None
-            and not self.qr_comm.disabled
-            and self.qr_comm.should_quick_allreduce(input_)
-        ):
-            outplace_all_reduce_method = "qr"
-        elif (
             self.ca_comm is not None
             and not self.ca_comm.disabled
             and self.ca_comm.should_custom_ar(input_)
         ):
             outplace_all_reduce_method = "ca"
+        elif (
+            self.qr_comm is not None
+            and not self.qr_comm.disabled
+            and self.qr_comm.should_quick_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "qr"
         elif (
             self.pymscclpp_comm is not None
             and not self.pymscclpp_comm.disabled
             and self.pymscclpp_comm.should_mscclpp_allreduce(input_)
         ):
             outplace_all_reduce_method = "pymscclpp"
+        elif (
+            self.symm_mem_comm is not None
+            and not self.symm_mem_comm.disabled
+            and self.symm_mem_comm.should_symm_mem_allreduce(input_)
+        ):
+            outplace_all_reduce_method = "symm_mem"
         if outplace_all_reduce_method is not None:
             return torch.ops.sglang.outplace_all_reduce(
                 input_,
@@ -553,16 +581,20 @@ class GroupCoordinator:
     def _all_reduce_out_place(
         self, input_: torch.Tensor, outplace_all_reduce_method: str
     ) -> torch.Tensor:
-        qr_comm = self.qr_comm
         ca_comm = self.ca_comm
+        qr_comm = self.qr_comm
         pymscclpp_comm = self.pymscclpp_comm
+        symm_mem_comm = self.symm_mem_comm
         assert any([qr_comm, ca_comm, pymscclpp_comm])
-        if outplace_all_reduce_method == "
-            assert not qr_comm.disabled
-            out = qr_comm.quick_all_reduce(input_)
-        elif outplace_all_reduce_method == "ca":
+        if outplace_all_reduce_method == "ca":
             assert not ca_comm.disabled
             out = ca_comm.custom_all_reduce(input_)
+        elif outplace_all_reduce_method == "qr":
+            assert not qr_comm.disabled
+            out = qr_comm.quick_all_reduce(input_)
+        elif outplace_all_reduce_method == "symm_mem":
+            assert not symm_mem_comm.disabled
+            out = symm_mem_comm.all_reduce(input_)
         else:
             assert not pymscclpp_comm.disabled
             out = pymscclpp_comm.all_reduce(input_)
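
For readers skimming the two hunks above: after this change the out-of-place all-reduce path is tried in the order custom allreduce ("ca"), quick allreduce ("qr"), MSCCL++ ("pymscclpp"), and finally torch symmetric memory ("symm_mem"), with each backend consulted only if it is present, enabled, and willing to take the tensor; otherwise the code falls back to plain torch.distributed. A minimal standalone sketch of that first-eligible-wins selection (the Backend helper and the size thresholds are illustrative, not sglang APIs). The remaining parallel_state.py hunks continue below.

from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Backend:
    # Illustrative stand-in for the ca/qr/pymscclpp/symm_mem communicators.
    name: str
    disabled: bool
    should_use: Callable[[int], bool]  # takes payload size in bytes


def pick_all_reduce_method(backends: list, nbytes: int) -> Optional[str]:
    # Same shape as GroupCoordinator.all_reduce: first eligible backend wins,
    # otherwise fall back to the plain torch.distributed path (None here).
    for b in backends:
        if b is not None and not b.disabled and b.should_use(nbytes):
            return b.name
    return None


backends = [
    Backend("ca", disabled=False, should_use=lambda n: n <= 512 * 1024),
    Backend("qr", disabled=True, should_use=lambda n: True),
    Backend("pymscclpp", disabled=False, should_use=lambda n: n <= 1 << 20),
    Backend("symm_mem", disabled=False, should_use=lambda n: n <= 8 << 20),
]
print(pick_all_reduce_method(backends, 256 * 1024))  # "ca"
print(pick_all_reduce_method(backends, 4 << 20))     # "symm_mem"
print(pick_all_reduce_method(backends, 64 << 20))    # None -> torch.distributed
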
@@ -637,7 +669,7 @@ class GroupCoordinator:
         )

     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if _is_npu or not
+        if _is_npu or not _supports_custom_op:
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -697,15 +729,13 @@ class GroupCoordinator:
         )

         # All-gather.
-        if input_.is_cpu and is_shm_available(
-            input_.dtype, self.world_size, self.local_size
-        ):
-            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-
         if input_.is_cpu:
-
-
-
+            if is_shm_available(input_.dtype, self.world_size, self.local_size):
+                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+            else:
+                torch.distributed.all_gather_into_tensor(
+                    output_tensor, input_, group=self.device_group
+                )
         else:
             self.all_gather_into_tensor(output_tensor, input_)

@@ -861,45 +891,63 @@ class GroupCoordinator:
         torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
         return objs

-    def send_object(
-
-
+    def send_object(
+        self,
+        obj: Any,
+        dst: int,
+        async_send: bool = False,
+    ) -> List[P2PWork]:
+        """
+        Send the input object list to the destination rank.
+        This function uses the CPU group for all communications.

-
+        TODO: If you want to use GPU communication, please add a new argument (e.g., data_group, group),
+        use other functions (e.g., send), or implement a new function (e.g., send_object_device).
+
+        NOTE: `dst` is the local rank of the destination rank.
+        """

+        assert dst < self.world_size, f"Invalid dst rank ({dst})"
         assert dst != self.rank_in_group, (
             "Invalid destination rank. Destination rank is the same "
             "as the current rank."
         )
+        send_func = torch.distributed.isend if async_send else torch.distributed.send

         # Serialize object to tensor and get the size as well
-        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
-            device=torch.cuda.current_device()
-        )
-
+        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
         size_tensor = torch.tensor(
-            [object_tensor.numel()],
-            dtype=torch.long,
-            device="cpu",
+            [object_tensor.numel()], dtype=torch.long, device="cpu"
         )
+
         # Send object size
-
+        p2p_work = []
+        size_work = send_func(
+            size_tensor,
+            self.ranks[dst],
+            group=self.cpu_group,
+        )
+        if async_send:
+            p2p_work.append(P2PWork(size_work, size_tensor))

-
-        torch.distributed.send(
+        object_work = send_func(
             object_tensor,
-
-            group=self.
+            self.ranks[dst],
+            group=self.cpu_group,
         )
+        if async_send:
+            p2p_work.append(P2PWork(object_work, object_tensor))

-        return
+        return p2p_work

-    def recv_object(
+    def recv_object(
+        self,
+        src: int,
+    ) -> Any:
         """Receive the input object list from the source rank."""
         """NOTE: `src` is the local rank of the source rank."""

         assert src < self.world_size, f"Invalid src rank ({src})"
-
         assert (
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
@@ -907,27 +955,25 @@ class GroupCoordinator:
         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

         # Receive object size
-
+        # We have to use irecv here to make it work for both isend and send.
+        work = torch.distributed.irecv(
             size_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()

         # Tensor to receive serialized objects into.
-        object_tensor = torch.empty(  # type: ignore[call-overload]
+        object_tensor: Any = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device=
+            device="cpu",
         )

-
-            object_tensor, src=self.ranks[src], group=self.
+        work = torch.distributed.irecv(
+            object_tensor, src=self.ranks[src], group=self.cpu_group
         )
+        work.wait()

-
-            rank_object == rank_size
-        ), "Received object sender rank does not match the size sender rank."
-
-        obj = pickle.loads(object_tensor.cpu().numpy())
-
+        obj = pickle.loads(object_tensor.numpy())
         return obj

     def broadcast_tensor_dict(
@@ -1017,12 +1063,13 @@ class GroupCoordinator:
         tensor_dict: Dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
-
+        async_send: bool = False,
+    ) -> Optional[List[P2PWork]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if
+        if self.world_size == 1:
             return tensor_dict

         all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
@@ -1047,7 +1094,10 @@
         # 1. Superior D2D transfer bandwidth
         # 2. Ability to overlap send and recv operations
         # Thus the net performance gain justifies this approach.
-
+
+        send_func = torch.distributed.isend if async_send else torch.distributed.send
+        p2p_works = self.send_object(metadata_list, dst=dst, async_send=async_send)
+
         for tensor in tensor_list:
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
@@ -1057,15 +1107,11 @@
             if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

-            if tensor.is_cpu
-
-
-
-
-            else:
-                # use group for GPU tensors
-                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
-        return None
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = send_func(tensor, self.ranks[dst], group=comm_group)
+            if async_send:
+                p2p_works.append(P2PWork(work, tensor))
+        return p2p_works

     def recv_tensor_dict(
         self,
@@ -1111,17 +1157,15 @@
                 orig_shape = tensor.shape
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

-
-
-
-
-
-
-
-            torch.distributed.recv(tensor, src=self.ranks[src], group=group)
+            # We have to use irecv here to make it work for both isend and send.
+            comm_group = metadata_group if tensor.is_cpu else group
+            work = torch.distributed.irecv(
+                tensor, src=self.ranks[src], group=comm_group
+            )
+            work.wait()
+
             if use_all_gather:
-
-                tensor = all_gather_group.all_gather(tensor, dim=0)  # type: ignore
+                tensor = all_gather_group.all_gather(tensor, dim=0)
                 tensor = tensor.reshape(orig_shape)

             tensor_dict[key] = tensor
@@ -1199,6 +1243,7 @@ def init_world_group(
         use_pynccl=False,
         use_pymscclpp=False,
         use_custom_allreduce=False,
+        use_torch_symm_mem=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
         use_npu_communicator=False,
@@ -1214,11 +1259,14 @@ def init_model_parallel_group(
     use_message_queue_broadcaster: bool = False,
     group_name: Optional[str] = None,
    use_mscclpp_allreduce: Optional[bool] = None,
+    use_symm_mem_allreduce: Optional[bool] = None,
 ) -> GroupCoordinator:
     if use_custom_allreduce is None:
         use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     if use_mscclpp_allreduce is None:
         use_mscclpp_allreduce = _ENABLE_MSCCLPP_ALL_REDUCE
+    if use_symm_mem_allreduce is None:
+        use_symm_mem_allreduce = _ENABLE_SYMM_MEM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
@@ -1226,6 +1274,7 @@ def init_model_parallel_group(
         use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
+        use_torch_symm_mem=use_symm_mem_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
         use_npu_communicator=True,
@@ -1311,6 +1360,7 @@ logger = logging.getLogger(__name__)

 _ENABLE_CUSTOM_ALL_REDUCE = True
 _ENABLE_MSCCLPP_ALL_REDUCE = False
+_ENABLE_SYMM_MEM_ALL_REDUCE = False


 def set_custom_all_reduce(enable: bool):
@@ -1323,6 +1373,11 @@ def set_mscclpp_all_reduce(enable: bool):
     _ENABLE_MSCCLPP_ALL_REDUCE = enable


+def set_symm_mem_all_reduce(enable: bool):
+    global _ENABLE_SYMM_MEM_ALL_REDUCE
+    _ENABLE_SYMM_MEM_ALL_REDUCE = enable
+
+
 def init_distributed_environment(
     world_size: int = -1,
     rank: int = -1,
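
The send_object / send_tensor_dict hunks above add an async_send mode: when it is enabled, sends go through torch.distributed.isend and the pending Work handles (together with the tensors that must stay alive until completion) are returned as P2PWork entries instead of blocking, while the receiver always uses irecv plus wait() so it works against either send path. A minimal, self-contained sketch of that pattern on the CPU "gloo" backend, with a pickled object sent as a uint8 tensor preceded by its size (names here are illustrative, not sglang APIs):

import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29511")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        payload = {"step": 1, "note": "hello"}
        obj = torch.frombuffer(bytearray(pickle.dumps(payload)), dtype=torch.uint8)
        size = torch.tensor([obj.numel()], dtype=torch.long)
        # Async path: keep (work, tensor) pairs alive until wait(), like P2PWork.
        pending = [(dist.isend(size, dst=1), size), (dist.isend(obj, dst=1), obj)]
        for work, _tensor in pending:
            work.wait()
    else:
        size = torch.empty(1, dtype=torch.long)
        dist.irecv(size, src=0).wait()
        obj = torch.empty(int(size.item()), dtype=torch.uint8)
        dist.irecv(obj, src=0).wait()
        print("rank 1 received:", pickle.loads(obj.numpy().tobytes()))

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
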
sglang/srt/entrypoints/engine.py
CHANGED
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
 )
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
+    DestroyWeightsUpdateGroupReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -433,6 +434,19 @@ class Engine(EngineBase):
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )

+    def destroy_weights_update_group(
+        self,
+        group_name: str,
+    ):
+        """Destroy parameter update group."""
+        obj = DestroyWeightsUpdateGroupReqInput(
+            group_name=group_name,
+        )
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.destroy_weights_update_group(obj, None)
+        )
+
     def update_weights_from_distributed(
         self,
         names: list[str],
@@ -666,6 +680,13 @@ def _set_envs_and_config(server_args: ServerArgs):
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
         os.environ["TRTLLM_ENABLE_PDL"] = "1"

+    if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
+        # Default to warning level, to avoid too many logs
+        os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+    if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
+        # Need to set log to console, otherwise the log level won't take effect
+        os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
+
     # Can also be passed as argument
     os.environ["SGLANG_RUN_ID"] = (
         f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
@@ -682,7 +703,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.
+            "0.4.0rc3",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -690,7 +711,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.14",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

@@ -791,7 +812,6 @@ def _launch_subprocesses(
                 pp_rank,
                 None,
                 writer,
-                None,
             ),
         )

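
The _set_envs_and_config hunks above only fill in defaults when the user has not already set the variables (CUTE_DSL_LOG_LEVEL=30, i.e. warning level, and CUTE_DSL_LOG_TO_CONSOLE=1), and bump the minimum flashinfer_python and sgl-kernel versions to 0.4.0rc3 and 0.3.14. A small sketch of the same set-if-unset pattern, handy when reproducing the environment outside the engine (the variable names come from the diff; the helper itself is illustrative, not an sglang API):

import os


def apply_default_env(defaults: dict) -> None:
    # Only fill in values the user has not already set, mirroring
    # _set_envs_and_config in sglang/srt/entrypoints/engine.py.
    for key, value in defaults.items():
        if os.environ.get(key) is None:
            os.environ[key] = value


apply_default_env(
    {
        "CUTE_DSL_LOG_LEVEL": "30",      # warning level, to avoid too many logs
        "CUTE_DSL_LOG_TO_CONSOLE": "1",  # required for the log level to take effect
    }
)
print(os.environ["CUTE_DSL_LOG_LEVEL"], os.environ["CUTE_DSL_LOG_TO_CONSOLE"])
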