sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -8,7 +8,7 @@ import hashlib
 import json
 import logging
 import os
-import
+import re
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -38,7 +38,8 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
-from sglang.srt.utils import print_warning_once
+from sglang.srt.utils import find_local_repo_dir, print_warning_once
+from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)

@@ -236,6 +237,149 @@ def get_quant_config(
     return quant_cls.from_config(config)


+def find_local_hf_snapshot_dir(
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    allow_patterns: List[str],
+    revision: Optional[str] = None,
+) -> Optional[str]:
+    """If the weights are already local, skip downloading and returns the path."""
+    if os.path.isdir(model_name_or_path):
+        return None
+
+    found_local_snapshot_dir = None
+
+    # Check custom cache_dir (if provided)
+    if cache_dir:
+        try:
+            repo_folder = os.path.join(
+                cache_dir,
+                huggingface_hub.constants.REPO_ID_SEPARATOR.join(
+                    ["models", *model_name_or_path.split("/")]
+                ),
+            )
+            rev_to_use = revision
+            if not rev_to_use:
+                ref_main = os.path.join(repo_folder, "refs", "main")
+                if os.path.isfile(ref_main):
+                    with open(ref_main) as f:
+                        rev_to_use = f.read().strip()
+            if rev_to_use:
+                rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
+                if os.path.isdir(rev_dir):
+                    found_local_snapshot_dir = rev_dir
+        except Exception as e:
+            logger.warning(
+                "Failed to find local snapshot in custom cache_dir %s: %s",
+                cache_dir,
+                e,
+            )
+
+    # Check default HF cache as well
+    if not found_local_snapshot_dir:
+        try:
+            rev_dir = find_local_repo_dir(model_name_or_path, revision)
+            if rev_dir and os.path.isdir(rev_dir):
+                found_local_snapshot_dir = rev_dir
+        except Exception as e:
+            logger.warning("Failed to find local snapshot in default HF cache: %s", e)
+
+    # if any incomplete file exists, force re-download by returning None
+    if found_local_snapshot_dir:
+        repo_folder = os.path.abspath(
+            os.path.join(found_local_snapshot_dir, "..", "..")
+        )
+        blobs_dir = os.path.join(repo_folder, "blobs")
+        if os.path.isdir(blobs_dir) and glob.glob(
+            os.path.join(blobs_dir, "*.incomplete")
+        ):
+            logger.info(
+                "Found .incomplete files in %s for %s. "
+                "Considering local snapshot incomplete.",
+                blobs_dir,
+                model_name_or_path,
+            )
+            return None
+
+    # if local snapshot exists, validate it contains at least one weight file
+    # matching allow_patterns before skipping download.
+    if found_local_snapshot_dir is None:
+        return None
+
+    local_weight_files: List[str] = []
+    try:
+        for pattern in allow_patterns:
+            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
+            for f in matched_files:
+                # os.path.exists returns False for broken symlinks.
+                if not os.path.exists(f):
+                    continue
+                local_weight_files.append(f)
+    except Exception as e:
+        logger.warning(
+            "Failed to scan local snapshot %s with patterns %s: %s",
+            found_local_snapshot_dir,
+            allow_patterns,
+            e,
+        )
+        local_weight_files = []
+
+    # After we have a list of valid files, check for sharded model completeness.
+    # Check if all safetensors with name model-{i}-of-{n}.safetensors exists
+    checked_sharded_model = False
+    for f in local_weight_files:
+        if checked_sharded_model:
+            break
+        base_name = os.path.basename(f)
+        # Regex for files like model-00001-of-00009.safetensors
+        match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
+        if match:
+            prefix = match.group(1)
+            shard_id_str = match.group(2)
+            total_shards_str = match.group(3)
+            suffix = match.group(4)
+            total_shards = int(total_shards_str)
+
+            # Check if all shards are present
+            missing_shards = []
+            for i in range(1, total_shards + 1):
+                # Reconstruct shard name, preserving padding of original shard id
+                shard_name = (
+                    f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
+                )
+                expected_path = os.path.join(found_local_snapshot_dir, shard_name)
+                # os.path.exists returns False for broken symlinks, which is desired.
+                if not os.path.exists(expected_path):
+                    missing_shards.append(shard_name)
+
+            if missing_shards:
+                logger.info(
+                    "Found incomplete sharded model %s. Missing shards: %s. "
+                    "Will attempt download.",
+                    model_name_or_path,
+                    missing_shards,
+                )
+                return None
+
+            # If we found and verified one set of shards, we are done.
+            checked_sharded_model = True
+
+    if len(local_weight_files) > 0:
+        logger.info(
+            "Found local HF snapshot for %s at %s; skipping download.",
+            model_name_or_path,
+            found_local_snapshot_dir,
+        )
+        return found_local_snapshot_dir
+    else:
+        logger.info(
+            "Local HF snapshot at %s has no files matching %s; will attempt download.",
+            found_local_snapshot_dir,
+            allow_patterns,
+        )
+        return None
+
+
 def download_weights_from_hf(
     model_name_or_path: str,
     cache_dir: Optional[str],
@@ -260,6 +404,16 @@ def download_weights_from_hf(
     Returns:
         str: The path to the downloaded model weights.
     """
+
+    if is_in_ci():
+        # If the weights are already local, skip downloading and returns the path.
+        # This is used to skip too-many Huggingface API calls in CI.
+        path = find_local_hf_snapshot_dir(
+            model_name_or_path, cache_dir, allow_patterns, revision
+        )
+        if path is not None:
+            return path
+
     if not huggingface_hub.constants.HF_HUB_OFFLINE:
         # Before we download we look at that is available:
         fs = HfFileSystem()
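Two rules in this new helper decide whether a cached snapshot counts as usable: it must contain at least one file matching allow_patterns with no *.incomplete blobs, and for sharded checkpoints every shard named by the model-{i}-of-{n} pattern must be present. The standalone sketch below mirrors the second rule with the same regex; missing_shard_names is a hypothetical helper written for illustration, not part of sglang, and it works on file basenames instead of a snapshot directory:

import re

def missing_shard_names(present_files):
    """Given basenames like 'model-00001-of-00003.safetensors', return the
    shards of the same series that are absent from present_files."""
    present = set(present_files)
    for name in present_files:
        m = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", name)
        if not m:
            continue
        prefix, shard_id, total_str, suffix = m.groups()
        # Reconstruct every expected shard name, preserving the zero padding
        # of the original shard id so the names match files on disk exactly.
        expected = [
            f"{prefix}-{i:0{len(shard_id)}d}-of-{total_str}.{suffix}"
            for i in range(1, int(total_str) + 1)
        ]
        return [s for s in expected if s not in present]
    return []

# A two-of-three snapshot is incomplete, so the download must be retried:
print(missing_shard_names(
    ["model-00001-of-00003.safetensors", "model-00003-of-00003.safetensors"]
))  # -> ['model-00002-of-00003.safetensors']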
sglang/srt/models/bailing_moe.py
CHANGED
@@ -45,12 +45,12 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
-    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -72,6 +72,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers

 LoraConfig = None
@@ -555,8 +559,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         attn_output, _ = self.dense(context_layer)
         return attn_output

@@ -702,7 +725,7 @@ class BailingMoEModel(nn.Module):
                 self.embed_dim,
                 quant_config=quant_config,
                 prefix=add_prefix("word_embeddings", prefix),
-
+                enable_tp=not is_dp_attention_enabled(),
             )
         else:
            self.word_embeddings = PPMissingLayer()
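The BailingMoE attention change follows the fused KV-buffer pattern used by other sglang models: when enable_fused_set_kv_buffer(forward_batch) is true, the rotary-embedding kernel writes K/V into the cache as a side effect, so the attention call is passed save_kv_cache=False to avoid writing the same entries twice. A minimal sketch of that invariant, using hypothetical stand-ins (this FusedSetKVBufferArg is a simplified placeholder and run_attention is not sglang's real signature):

from typing import Optional

class FusedSetKVBufferArg:
    """Stand-in for the argument asking the RoPE kernel to also store K/V."""
    def __init__(self, value):
        self.value = value

def run_attention(q, k, v, fused_arg: Optional[FusedSetKVBufferArg], save_kv_cache: bool):
    # Exactly one of the two paths may write the KV cache: the fused path
    # (RoPE kernel stores K/V) or the fallback (attention backend stores it).
    assert (fused_arg is not None) != save_kv_cache, "KV cache written twice or never"
    return q  # placeholder for the attention output

# Mirrors save_kv_cache=not enable_fused_set_kv_buffer(forward_batch):
for fused_enabled in (True, False):
    fused_arg = FusedSetKVBufferArg(value="v") if fused_enabled else None
    run_attention("q", "k", "v", fused_arg, save_kv_cache=not fused_enabled)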
|
@@ -33,11 +33,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
|
33
33
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
34
34
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
35
35
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
36
|
-
from sglang.srt.utils import BumpAllocator, add_prefix
|
36
|
+
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
|
37
37
|
|
38
38
|
logger = logging.getLogger(__name__)
|
39
39
|
|
40
40
|
|
41
|
+
_is_cuda = is_cuda()
|
42
|
+
|
43
|
+
|
41
44
|
class DeepseekModelNextN(nn.Module):
|
42
45
|
def __init__(
|
43
46
|
self,
|
@@ -66,12 +69,14 @@ class DeepseekModelNextN(nn.Module):
|
|
66
69
|
|
67
70
|
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
68
71
|
|
72
|
+
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
|
69
73
|
self.decoder = DeepseekV2DecoderLayer(
|
70
74
|
config,
|
71
75
|
0,
|
72
76
|
quant_config=quant_config,
|
73
77
|
is_nextn=True,
|
74
78
|
prefix=add_prefix("decoder", prefix),
|
79
|
+
alt_stream=self.alt_stream,
|
75
80
|
)
|
76
81
|
|
77
82
|
self.shared_head = nn.Module()
|