sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -25,17 +25,19 @@ import time
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
-from urllib.parse import urlparse

-import requests
 import torch
 import torch.distributed as dist

 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
-from sglang.srt.configs.model_config import
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
-from sglang.srt.connector import ConnectorType
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_pp_group,
@@ -45,6 +47,7 @@ from sglang.srt.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -60,6 +63,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -94,6 +101,7 @@ from sglang.srt.mem_cache.memory_pool import (
     HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
+    NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
@@ -103,6 +111,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_init_weights_send_group_for_remote_instance_request,
+)
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.offloader import (
@@ -110,10 +121,6 @@ from sglang.srt.offloader import (
     get_offloader,
     set_offloader,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
-from sglang.srt.remote_instance_weight_loader_utils import (
-    trigger_init_weights_send_group_for_remote_instance_request,
-)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,7 +134,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_blackwell,
     is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
@@ -135,16 +141,38 @@ from sglang.srt.utils import (
     is_no_spec_infer_or_topk_one,
     is_npu,
     is_sm100_supported,
+    log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    parse_connector_type,
     set_cuda_arch,
+    slow_rank_detector,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
 )

+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+    "nsa",
+]
+
+
+def add_mla_attention_backend(backend_name):
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -158,6 +186,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)


+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
+
 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

@@ -252,6 +287,9 @@ class ModelRunner:
         # CPU offload
         set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

+        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+            slow_rank_detector.execute()
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -319,7 +357,6 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             logger.warning("Hybrid GDN model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            self.server_args.attention_backend = "hybrid_linear_attn"
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
                     self.server_args.max_mamba_cache_size = (
@@ -385,6 +422,12 @@ class ModelRunner:
             )
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -496,9 +539,7 @@ class ModelRunner:
         elif _is_hip:
             head_num = self.model_config.get_num_kv_heads(self.tp_size)
             # TODO current aiter only support head number 16 or 128 head number
-            if
-                head_num == 128 or head_num == 16
-            ) and self.spec_algorithm.is_none():
+            if head_num == 128 or head_num == 16:
                 server_args.attention_backend = "aiter"
             else:
                 server_args.attention_backend = "triton"
@@ -511,16 +552,7 @@ class ModelRunner:
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
-                if server_args.attention_backend in
-                    "aiter",
-                    "flashinfer",
-                    "fa3",
-                    "triton",
-                    "flashmla",
-                    "cutlass_mla",
-                    "trtllm_mla",
-                    "ascend",
-                ]:
+                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -562,18 +594,6 @@ class ModelRunner:
         if not self.use_mla_backend:
             server_args.disable_chunked_prefix_cache = True

-        # TODO(kaixih@nvidia): remove this once we have a better solution for DP attention.
-        # For more details, see: https://github.com/sgl-project/sglang/issues/8616
-        elif (
-            self.dp_size > 1
-            and is_sm100_supported()
-            and server_args.attention_backend != "triton"
-            and server_args.attention_backend == "trtllm_mla"
-        ):
-            logger.info(
-                "Disable chunked prefix cache when dp size > 1 and attention backend is not triton."
-            )
-            server_args.disable_chunked_prefix_cache = True
         if not server_args.disable_chunked_prefix_cache:
             logger.info("Chunked prefix cache is turned on.")

@@ -599,7 +619,7 @@ class ModelRunner:
             server_args.hicache_io_backend = "direct"
             logger.warning(
                 "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-
+                "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
             )

     def init_torch_distributed(self):
@@ -634,6 +654,7 @@ class ModelRunner:
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

         if not self.is_draft_worker:
             if self.device == "cpu":
@@ -730,6 +751,10 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
+            tp_rank=self.tp_rank,
+            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -757,7 +782,10 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region(
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -1035,6 +1063,19 @@ class ModelRunner:
             logger.error(message)
             return False, message

+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
@@ -1072,7 +1113,7 @@ class ModelRunner:
                 handle.wait()

             self.model.load_weights(weights)
-            return True,
+            return True, "Succeeded to update parameter online."

         except Exception as e:
             error_msg = (
@@ -1176,6 +1217,7 @@ class ModelRunner:
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
             lora_paths=self.server_args.lora_paths,
+            server_args=self.server_args,
         )

     def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1260,6 +1302,7 @@ class ModelRunner:
         return self.model_config.hf_config.architectures[0] in [
             "Qwen3NextForCausalLM",
             "Qwen3NextForCausalLMMTP",
+            "FalconH1ForCausalLM",
         ]

     def set_num_token_hybrid(self):
@@ -1352,7 +1395,18 @@ class ModelRunner:
     ):
         # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
-
+            quant_config = getattr(self.model, "quant_config", None)
+            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
+            if (
+                isinstance(kv_cache_quant_algo, str)
+                and kv_cache_quant_algo.upper() == "FP8"
+            ):
+                if _is_hip:
+                    self.kv_cache_dtype = torch.float8_e4m3fnuz
+                else:
+                    self.kv_cache_dtype = torch.float8_e4m3fn
+            else:
+                self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
@@ -1368,6 +1422,8 @@ class ModelRunner:
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
             )

+        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if SGLANG_CI_SMALL_KV_SIZE:
             self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1385,7 +1441,7 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

-        if
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                 max_num_reqs = self.server_args.max_num_reqs
@@ -1438,7 +1494,8 @@ class ModelRunner:

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enough memory. Please try to increase --mem-fraction-static."
+                f"Not enough memory. Please try to increase --mem-fraction-static. "
+                f"Current value: {self.server_args.mem_fraction_static=}"
             )

         # Initialize req_to_token_pool
@@ -1497,6 +1554,7 @@ class ModelRunner:
             assert self.is_draft_worker

         # Initialize token_to_kv_pool
+        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1505,6 +1563,7 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     kv_lora_rank=self.model_config.kv_lora_rank,
                     qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    index_head_dim=self.model_config.index_head_dim,
                     layer_num=self.num_effective_layers,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1524,7 +1583,22 @@ class ModelRunner:
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
+        elif self.use_mla_backend and is_nsa_model:
+            self.token_to_kv_pool = NSATokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+            )
         elif self.use_mla_backend:
+            assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -1568,7 +1642,7 @@ class ModelRunner:
             )
         elif self.is_hybrid_gdn:
             self.token_to_kv_pool = HybridLinearKVPool(
-                page_size=self.page_size
+                page_size=self.page_size,
                 size=self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(
@@ -1603,10 +1677,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if _is_npu and
-                "ascend"
-
-            ]:
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1700,8 +1773,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-
-
+                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
@@ -1717,140 +1790,10 @@ class ModelRunner:
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
-        if backend_str
-            if not self.use_mla_backend:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                )
-
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    if (
-                        not hasattr(self, "plan_stream_for_flashinfer")
-                        or not self.plan_stream_for_flashinfer
-                    ):
-                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                return FlashInferAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                )
-
-                return FlashInferMLAAttnBackend(self)
-        elif backend_str == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "wave":
-            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
-
-            return WaveAttnBackend(self)
-        elif backend_str == "ascend":
-            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-            return AscendAttnBackend(self)
-        elif backend_str == "triton":
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                from sglang.srt.layers.attention.double_sparsity_backend import (
-                    DoubleSparseAttnBackend,
-                )
-
-                return DoubleSparseAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                return TritonAttnBackend(self)
-        elif backend_str == "torch_native":
-            from sglang.srt.layers.attention.torch_native_backend import (
-                TorchNativeAttnBackend,
-            )
-
-            return TorchNativeAttnBackend(self)
-        elif backend_str == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
-
-            return FlashMLABackend(self)
-        elif backend_str == "fa3":
-            assert (
-                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
-            ) or torch.cuda.get_device_capability()[0] == 9, (
-                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self)
-        elif backend_str == "cutlass_mla":
-            from sglang.srt.layers.attention.cutlass_mla_backend import (
-                CutlassMLABackend,
-            )
-
-            return CutlassMLABackend(self)
-        elif backend_str == "trtllm_mla":
-            if not self.use_mla_backend:
-                raise ValueError("trtllm_mla backend can only be used with MLA models.")
-            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-            return TRTLLMMLABackend(self)
-        elif backend_str == "trtllm_mha":
-            if self.use_mla_backend:
-                raise ValueError(
-                    "trtllm_mha backend can only be used with non-MLA models."
-                )
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-            )
-
-            return TRTLLMHAAttnBackend(self)
-        elif backend_str == "intel_amx":
-            from sglang.srt.layers.attention.intel_amx_backend import (
-                IntelAMXAttnBackend,
-            )
-
-            return IntelAMXAttnBackend(self)
-        elif backend_str == "dual_chunk_flash_attn":
-            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
-                DualChunkFlashAttentionBackend,
-            )
-
-            return DualChunkFlashAttentionBackend(self)
-        elif backend_str == "hybrid_linear_attn":
-            assert (
-                self.is_hybrid_gdn
-            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-                HybridLinearAttnBackend,
-                MambaAttnBackend,
-            )
-
-            if _is_npu:
-                from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-                full_attn_backend = AscendAttnBackend(self)
-            elif is_blackwell():
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                full_attn_backend = TritonAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashattention_backend import (
-                    FlashAttentionBackend,
-                )
-
-                full_attn_backend = FlashAttentionBackend(self)
-
-            linear_attn_backend = MambaAttnBackend(self)
-            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
-            return HybridLinearAttnBackend(
-                full_attn_backend, linear_attn_backend, full_attn_layers
-            )
-        else:
+        if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -2147,7 +2090,6 @@
             )

         self._preprocess_logits(logits_output, forward_batch.sampling_info)
-
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
@@ -2155,6 +2097,12 @@
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )
         return next_token_ids

sglang/srt/model_executor/npu_graph_runner.py

@@ -19,8 +19,10 @@ import logging
 import threading
 from typing import TYPE_CHECKING, Optional, Union

+import numpy as np
 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

 logger = logging.getLogger(__name__)
@@ -73,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
         self.positions[: self.raw_num_token].copy_(forward_batch.positions)

         # Replay
-
-
-
-
-
+        if self.model_runner.model_config.index_head_dim is None:
+            seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
+                self.bs - self.raw_bs
+            )
+            thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
+            thread.start()
+            self.graphs[self.bs].replay()
+            thread.join()
+        else:
+            self.graphs[self.bs].replay()

         output = self.output_buffers[self.bs]
         if isinstance(output, LogitsProcessorOutput):
sglang/srt/model_loader/loader.py

@@ -54,6 +54,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_transferring_weights_request,
+)
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
     post_load_weights,
@@ -77,9 +80,6 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
     set_runai_streamer_env,
 )
-from sglang.srt.remote_instance_weight_loader_utils import (
-    trigger_transferring_weights_request,
-)
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_capability,
@@ -206,7 +206,10 @@ def _initialize_model(
     if _is_npu:
         packed_modules_mapping.update(
             {
-                "visual": {
+                "visual": {
+                    "qkv_proj": ["qkv"],
+                    "gate_up_proj": ["gate_proj", "up_proj"],
+                },
                 "vision_model": {
                     "qkv_proj": ["q_proj", "k_proj", "v_proj"],
                     "proj": ["out_proj"],
@@ -1417,7 +1420,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
                 f"load format {load_config.load_format}"
             )

-        model_weights = f"instance://{
+        model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"

         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
@@ -1439,11 +1442,12 @@ class RemoteInstanceModelLoader(BaseModelLoader):
     def load_model_from_remote_instance(
         self, model, client, model_config: ModelConfig, device_config: DeviceConfig
     ) -> nn.Module:
+        load_config = self.load_config
         instance_ip = socket.gethostbyname(socket.gethostname())
         start_build_group_tic = time.time()
         client.build_group(
             gpu_id=device_config.gpu_id,
-            tp_rank=
+            tp_rank=load_config.tp_rank,
             instance_ip=instance_ip,
         )
         torch.cuda.synchronize()
@@ -1452,13 +1456,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
             f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
         )

-        if
+        if load_config.tp_rank == 0:
             t = threading.Thread(
                 target=trigger_transferring_weights_request,
                 args=(
-
-
-
+                    load_config.remote_instance_weight_loader_seed_instance_ip,
+                    load_config.remote_instance_weight_loader_seed_instance_service_port,
+                    load_config.remote_instance_weight_loader_send_weights_group_ports,
                     instance_ip,
                 ),
             )