sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -19,8 +19,10 @@ import logging
|
|
19
19
|
import threading
|
20
20
|
from typing import TYPE_CHECKING, Optional, Union
|
21
21
|
|
22
|
+
import numpy as np
|
22
23
|
import torch
|
23
24
|
|
25
|
+
from sglang.srt.configs.model_config import AttentionArch
|
24
26
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
25
27
|
|
26
28
|
logger = logging.getLogger(__name__)
|
@@ -73,11 +75,16 @@ class NPUGraphRunner(CudaGraphRunner):
|
|
73
75
|
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
|
74
76
|
|
75
77
|
# Replay
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
78
|
+
if self.model_runner.model_config.index_head_dim is None:
|
79
|
+
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
|
80
|
+
self.bs - self.raw_bs
|
81
|
+
)
|
82
|
+
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
|
83
|
+
thread.start()
|
84
|
+
self.graphs[self.bs].replay()
|
85
|
+
thread.join()
|
86
|
+
else:
|
87
|
+
self.graphs[self.bs].replay()
|
81
88
|
|
82
89
|
output = self.output_buffers[self.bs]
|
83
90
|
if isinstance(output, LogitsProcessorOutput):
|
@@ -1,16 +1,22 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
3
7
|
from torch import nn
|
4
8
|
|
5
|
-
from sglang.srt.configs.device_config import DeviceConfig
|
6
|
-
from sglang.srt.configs.load_config import LoadConfig
|
7
|
-
from sglang.srt.configs.model_config import ModelConfig
|
8
9
|
from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
|
9
10
|
from sglang.srt.model_loader.utils import (
|
10
11
|
get_architecture_class_name,
|
11
12
|
get_model_architecture,
|
12
13
|
)
|
13
14
|
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from sglang.srt.configs.device_config import DeviceConfig
|
17
|
+
from sglang.srt.configs.load_config import LoadConfig
|
18
|
+
from sglang.srt.configs.model_config import ModelConfig
|
19
|
+
|
14
20
|
|
15
21
|
def get_model(
|
16
22
|
*,
|
@@ -1,5 +1,7 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
# ruff: noqa: SIM117
|
4
6
|
import collections
|
5
7
|
import concurrent
|
@@ -10,14 +12,29 @@ import json
|
|
10
12
|
import logging
|
11
13
|
import math
|
12
14
|
import os
|
15
|
+
import re
|
16
|
+
import socket
|
17
|
+
import threading
|
13
18
|
import time
|
14
19
|
from abc import ABC, abstractmethod
|
15
20
|
from concurrent.futures import ThreadPoolExecutor
|
16
21
|
from contextlib import contextmanager
|
17
|
-
from typing import
|
22
|
+
from typing import (
|
23
|
+
TYPE_CHECKING,
|
24
|
+
Any,
|
25
|
+
Dict,
|
26
|
+
Generator,
|
27
|
+
Iterable,
|
28
|
+
List,
|
29
|
+
Optional,
|
30
|
+
Tuple,
|
31
|
+
cast,
|
32
|
+
)
|
33
|
+
from urllib.parse import urlparse
|
18
34
|
|
19
35
|
import huggingface_hub
|
20
36
|
import numpy as np
|
37
|
+
import requests
|
21
38
|
import safetensors.torch
|
22
39
|
import torch
|
23
40
|
from huggingface_hub import HfApi, hf_hub_download
|
@@ -26,9 +43,7 @@ from tqdm.auto import tqdm
|
|
26
43
|
from transformers import AutoModelForCausalLM
|
27
44
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
28
45
|
|
29
|
-
from sglang.srt.configs.device_config import DeviceConfig
|
30
46
|
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
31
|
-
from sglang.srt.configs.model_config import ModelConfig
|
32
47
|
from sglang.srt.connector import (
|
33
48
|
ConnectorType,
|
34
49
|
create_remote_connector,
|
@@ -39,7 +54,9 @@ from sglang.srt.distributed import (
|
|
39
54
|
get_tensor_model_parallel_rank,
|
40
55
|
get_tensor_model_parallel_world_size,
|
41
56
|
)
|
42
|
-
from sglang.srt.
|
57
|
+
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
58
|
+
trigger_transferring_weights_request,
|
59
|
+
)
|
43
60
|
from sglang.srt.model_loader.utils import (
|
44
61
|
get_model_architecture,
|
45
62
|
post_load_weights,
|
@@ -47,6 +64,7 @@ from sglang.srt.model_loader.utils import (
|
|
47
64
|
)
|
48
65
|
from sglang.srt.model_loader.weight_utils import (
|
49
66
|
_BAR_FORMAT,
|
67
|
+
default_weight_loader,
|
50
68
|
download_safetensors_index_file_from_hf,
|
51
69
|
download_weights_from_hf,
|
52
70
|
filter_duplicate_safetensors_files,
|
@@ -70,6 +88,11 @@ from sglang.srt.utils import (
|
|
70
88
|
set_weight_attrs,
|
71
89
|
)
|
72
90
|
|
91
|
+
if TYPE_CHECKING:
|
92
|
+
from sglang.srt.configs.device_config import DeviceConfig
|
93
|
+
from sglang.srt.configs.model_config import ModelConfig
|
94
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
95
|
+
|
73
96
|
_is_npu = is_npu()
|
74
97
|
|
75
98
|
|
@@ -183,7 +206,10 @@ def _initialize_model(
|
|
183
206
|
if _is_npu:
|
184
207
|
packed_modules_mapping.update(
|
185
208
|
{
|
186
|
-
"visual": {
|
209
|
+
"visual": {
|
210
|
+
"qkv_proj": ["qkv"],
|
211
|
+
"gate_up_proj": ["gate_proj", "up_proj"],
|
212
|
+
},
|
187
213
|
"vision_model": {
|
188
214
|
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
189
215
|
"proj": ["out_proj"],
|
@@ -1366,6 +1392,105 @@ class GGUFModelLoader(BaseModelLoader):
|
|
1366
1392
|
return model
|
1367
1393
|
|
1368
1394
|
|
1395
|
+
class RemoteInstanceModelLoader(BaseModelLoader):
|
1396
|
+
"""Model loader that can load Tensors from remote sglang instance."""
|
1397
|
+
|
1398
|
+
def __init__(self, load_config: LoadConfig):
|
1399
|
+
super().__init__(load_config)
|
1400
|
+
if load_config.model_loader_extra_config:
|
1401
|
+
raise ValueError(
|
1402
|
+
f"Model loader extra config is not supported for "
|
1403
|
+
f"load format {load_config.load_format}"
|
1404
|
+
)
|
1405
|
+
|
1406
|
+
def download_model(self, model_config: ModelConfig) -> None:
|
1407
|
+
raise NotImplementedError
|
1408
|
+
|
1409
|
+
def load_model(
|
1410
|
+
self,
|
1411
|
+
*,
|
1412
|
+
model_config: ModelConfig,
|
1413
|
+
device_config: DeviceConfig,
|
1414
|
+
) -> nn.Module:
|
1415
|
+
logger.info("Loading weights from remote instance ...")
|
1416
|
+
load_config = self.load_config
|
1417
|
+
|
1418
|
+
assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
|
1419
|
+
f"Model loader {self.load_config.load_format} is not supported for "
|
1420
|
+
f"load format {load_config.load_format}"
|
1421
|
+
)
|
1422
|
+
|
1423
|
+
model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
|
1424
|
+
|
1425
|
+
with set_default_torch_dtype(model_config.dtype):
|
1426
|
+
with torch.device(device_config.device):
|
1427
|
+
model = _initialize_model(model_config, self.load_config)
|
1428
|
+
|
1429
|
+
with create_remote_connector(model_weights, device_config.device) as client:
|
1430
|
+
connector_type = get_connector_type(client)
|
1431
|
+
if connector_type == ConnectorType.INSTANCE:
|
1432
|
+
self.load_model_from_remote_instance(
|
1433
|
+
model, client, model_config, device_config
|
1434
|
+
)
|
1435
|
+
else:
|
1436
|
+
raise ValueError(
|
1437
|
+
f"Unsupported connector type {connector_type} for "
|
1438
|
+
f"remote tensor model loading."
|
1439
|
+
)
|
1440
|
+
return model.eval()
|
1441
|
+
|
1442
|
+
def load_model_from_remote_instance(
|
1443
|
+
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
1444
|
+
) -> nn.Module:
|
1445
|
+
load_config = self.load_config
|
1446
|
+
instance_ip = socket.gethostbyname(socket.gethostname())
|
1447
|
+
start_build_group_tic = time.time()
|
1448
|
+
client.build_group(
|
1449
|
+
gpu_id=device_config.gpu_id,
|
1450
|
+
tp_rank=load_config.tp_rank,
|
1451
|
+
instance_ip=instance_ip,
|
1452
|
+
)
|
1453
|
+
torch.cuda.synchronize()
|
1454
|
+
end_build_group_tic = time.time()
|
1455
|
+
logger.debug(
|
1456
|
+
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
|
1457
|
+
)
|
1458
|
+
|
1459
|
+
if load_config.tp_rank == 0:
|
1460
|
+
t = threading.Thread(
|
1461
|
+
target=trigger_transferring_weights_request,
|
1462
|
+
args=(
|
1463
|
+
load_config.remote_instance_weight_loader_seed_instance_ip,
|
1464
|
+
load_config.remote_instance_weight_loader_seed_instance_service_port,
|
1465
|
+
load_config.remote_instance_weight_loader_send_weights_group_ports,
|
1466
|
+
instance_ip,
|
1467
|
+
),
|
1468
|
+
)
|
1469
|
+
t.start()
|
1470
|
+
|
1471
|
+
start_get_weights_tic = time.time()
|
1472
|
+
with set_default_torch_dtype(model_config.dtype):
|
1473
|
+
for _, tensor in model.named_parameters():
|
1474
|
+
torch.distributed.broadcast(
|
1475
|
+
tensor.data,
|
1476
|
+
src=0,
|
1477
|
+
group=client._model_update_group,
|
1478
|
+
)
|
1479
|
+
torch.cuda.synchronize()
|
1480
|
+
|
1481
|
+
if hasattr(model, "post_load_weights"):
|
1482
|
+
model.post_load_weights()
|
1483
|
+
end_get_weights_tic = time.time()
|
1484
|
+
logger.debug(
|
1485
|
+
f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
|
1486
|
+
)
|
1487
|
+
# destroy the process group after loading weights
|
1488
|
+
torch.distributed.distributed_c10d.destroy_process_group(
|
1489
|
+
client._model_update_group
|
1490
|
+
)
|
1491
|
+
torch.cuda.empty_cache()
|
1492
|
+
|
1493
|
+
|
1369
1494
|
class RemoteModelLoader(BaseModelLoader):
|
1370
1495
|
"""Model loader that can load Tensors from remote database."""
|
1371
1496
|
|
@@ -1567,4 +1692,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
|
1567
1692
|
if load_config.load_format == LoadFormat.REMOTE:
|
1568
1693
|
return RemoteModelLoader(load_config)
|
1569
1694
|
|
1695
|
+
if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
|
1696
|
+
return RemoteInstanceModelLoader(load_config)
|
1697
|
+
|
1570
1698
|
return DefaultModelLoader(load_config)
|
@@ -0,0 +1,69 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import List
|
5
|
+
|
6
|
+
import requests
|
7
|
+
|
8
|
+
logger = logging.getLogger(__name__)
|
9
|
+
|
10
|
+
|
11
|
+
def trigger_init_weights_send_group_for_remote_instance_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    """Ask the seed instance to set up its side of the weight-sending groups.

    Sends a POST to ``/init_weights_send_group_for_remote_instance`` on the
    seed instance's HTTP service. The payload names the seed as the group
    master with ``group_rank`` 0, ``world_size`` 2 and the ``nccl`` backend.

    Args:
        remote_instance_weight_loader_seed_instance_ip: Address of the seed
            instance; used both for the service URL and as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports for the
            send-weights groups; sent as one comma-separated string.
        remote_instance_weight_loader_client_id: Identifier appended to
            ``send_weights_`` to form the group name.

    Raises:
        Exception: Whatever ``requests`` raises while sending the request, or
            ``requests.HTTPError`` when the seed answers with a 4xx/5xx
            status. The error is logged and then re-raised.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    # Only loading from an instance with the same parallelism strategy is
    # supported. Each TP rank pair between the seed and destination instances
    # builds its own communication group for sending weights,
    # i.e. seed TP 0 <-> dst TP 0, seed TP 1 <-> dst TP 1, etc.
    # Each communication group has a world size of 2.
    try:
        response = requests.post(
            f"{seed_instance_service_url}/init_weights_send_group_for_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_rank": 0,
                "world_size": 2,
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
                "backend": "nccl",
            },
        )
        # An error status means the request was not honoured; report it
        # through the same log-and-re-raise path instead of ignoring it.
        response.raise_for_status()
    except Exception as e:
        logger.error(
            f"Failed to trigger init_weights_send_group_for_remote_instance_request to seed instance {seed_instance_service_url}: {e}."
        )
        raise
|
44
|
+
|
45
|
+
|
46
|
+
def trigger_transferring_weights_request(
    remote_instance_weight_loader_seed_instance_ip: str,
    remote_instance_weight_loader_seed_instance_service_port: int,
    remote_instance_weight_loader_send_weights_group_ports: List[int],
    remote_instance_weight_loader_client_id: str,
):
    """Ask the seed instance to start sending its weights.

    Sends a POST to ``/send_weights_to_remote_instance`` on the seed
    instance's HTTP service, identifying the target group by its master
    address, ports and name.

    Args:
        remote_instance_weight_loader_seed_instance_ip: Address of the seed
            instance; used both for the service URL and as ``master_address``.
        remote_instance_weight_loader_seed_instance_service_port: Port of the
            seed instance's HTTP service.
        remote_instance_weight_loader_send_weights_group_ports: Ports of the
            send-weights groups; sent as one comma-separated string.
        remote_instance_weight_loader_client_id: Identifier appended to
            ``send_weights_`` to form the group name.

    Raises:
        Exception: Whatever ``requests`` raises while sending the request, or
            ``requests.HTTPError`` when the seed answers with a 4xx/5xx
            status. The error is logged and then re-raised.
    """
    seed_instance_service_url = f"http://{remote_instance_weight_loader_seed_instance_ip}:{remote_instance_weight_loader_seed_instance_service_port}"
    try:
        response = requests.post(
            f"{seed_instance_service_url}/send_weights_to_remote_instance",
            json={
                "master_address": remote_instance_weight_loader_seed_instance_ip,
                "ports": ",".join(
                    str(p)
                    for p in remote_instance_weight_loader_send_weights_group_ports
                ),
                "group_name": f"send_weights_{remote_instance_weight_loader_client_id}",
            },
        )
        # An error status means the request was not honoured; report it
        # through the same log-and-re-raise path instead of ignoring it.
        response.raise_for_status()
    except Exception as e:
        logger.error(f"Failed to trigger send weights to remote instance request: {e}")
        raise
|
@@ -8,7 +8,7 @@ import hashlib
|
|
8
8
|
import json
|
9
9
|
import logging
|
10
10
|
import os
|
11
|
-
import
|
11
|
+
import re
|
12
12
|
import tempfile
|
13
13
|
from collections import defaultdict
|
14
14
|
from typing import (
|
@@ -35,9 +35,11 @@ from tqdm.auto import tqdm
|
|
35
35
|
from sglang.srt.configs.load_config import LoadConfig
|
36
36
|
from sglang.srt.configs.model_config import ModelConfig
|
37
37
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
38
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
38
39
|
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
39
40
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
40
|
-
from sglang.srt.utils import print_warning_once
|
41
|
+
from sglang.srt.utils import find_local_repo_dir, print_warning_once
|
42
|
+
from sglang.utils import is_in_ci
|
41
43
|
|
42
44
|
logger = logging.getLogger(__name__)
|
43
45
|
|
@@ -235,6 +237,149 @@ def get_quant_config(
|
|
235
237
|
return quant_cls.from_config(config)
|
236
238
|
|
237
239
|
|
240
|
+
def _find_snapshot_in_cache_dir(
    cache_dir: str, model_name_or_path: str, revision: Optional[str]
) -> Optional[str]:
    """Return the snapshot directory for a model inside a custom cache dir.

    An explicit ``revision`` is first tried verbatim as a directory under
    ``snapshots/`` (the commit-hash case). If that directory does not exist,
    or no revision was given, the name (default ``main``) is looked up as a
    file under ``refs/`` and the hash stored there selects the snapshot.
    Returns ``None`` when no matching snapshot directory exists.
    """
    repo_folder = os.path.join(
        cache_dir,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *model_name_or_path.split("/")]
        ),
    )
    snapshots_dir = os.path.join(repo_folder, "snapshots")

    if revision:
        direct_dir = os.path.join(snapshots_dir, revision)
        if os.path.isdir(direct_dir):
            return direct_dir

    # Branch and tag names are not snapshot directories themselves; the file
    # under refs/ holds the commit hash they currently point to.
    ref_file = os.path.join(repo_folder, "refs", revision or "main")
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file) as f:
        commit_hash = f.read().strip()
    if not commit_hash:
        return None

    rev_dir = os.path.join(snapshots_dir, commit_hash)
    return rev_dir if os.path.isdir(rev_dir) else None


def _list_existing_matches(snapshot_dir: str, allow_patterns: List[str]) -> List[str]:
    """Return files in ``snapshot_dir`` matching any pattern that really exist."""
    # os.path.exists returns False for broken symlinks, which filters them out.
    return [
        path
        for pattern in allow_patterns
        for path in glob.glob(os.path.join(snapshot_dir, pattern))
        if os.path.exists(path)
    ]


def _find_missing_shards(snapshot_dir: str, weight_files: List[str]) -> List[str]:
    """Return the shard file names absent from the first sharded set found.

    Only the first file named like ``model-00001-of-00009.safetensors`` is
    inspected; its prefix, padding width, shard count and suffix define the
    complete set of expected names. Returns an empty list when no file is
    sharded or when every shard is present.
    """
    for path in weight_files:
        # Regex for files like model-00001-of-00009.safetensors
        match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", os.path.basename(path))
        if match is None:
            continue
        prefix, shard_id_str, total_shards_str, suffix = match.groups()
        # Rebuild every shard name with the zero padding of the original id.
        width = len(shard_id_str)
        expected_names = (
            f"{prefix}-{i:0{width}d}-of-{total_shards_str}.{suffix}"
            for i in range(1, int(total_shards_str) + 1)
        )
        # os.path.exists returns False for broken symlinks, which is desired.
        # Once one set of shards has been examined, we are done.
        return [
            name
            for name in expected_names
            if not os.path.exists(os.path.join(snapshot_dir, name))
        ]
    return []


def find_local_hf_snapshot_dir(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """Return the path of a complete local HF snapshot, or None to download.

    The custom ``cache_dir`` (if given) is searched first, then the default
    HF cache via ``find_local_repo_dir``. A snapshot is accepted only when
    the repo's ``blobs`` directory has no ``*.incomplete`` files, at least one
    existing file matches ``allow_patterns``, and the first sharded file set
    found has all of its shards.

    Args:
        model_name_or_path: Model repo id. If it is an existing local
            directory, ``None`` is returned immediately.
        cache_dir: Optional custom cache directory to search first.
        allow_patterns: Glob patterns identifying weight files.
        revision: Optional revision; a commit hash or a ref name.

    Returns:
        The snapshot directory if it passes all checks, otherwise ``None``.
    """
    if os.path.isdir(model_name_or_path):
        return None

    found_local_snapshot_dir = None

    # Check custom cache_dir (if provided)
    if cache_dir:
        try:
            found_local_snapshot_dir = _find_snapshot_in_cache_dir(
                cache_dir, model_name_or_path, revision
            )
        except Exception as e:
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # Check default HF cache as well
    if not found_local_snapshot_dir:
        try:
            rev_dir = find_local_repo_dir(model_name_or_path, revision)
            if rev_dir and os.path.isdir(rev_dir):
                found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    if not found_local_snapshot_dir:
        return None

    # If any incomplete file exists, force re-download by returning None.
    repo_folder = os.path.abspath(os.path.join(found_local_snapshot_dir, "..", ".."))
    blobs_dir = os.path.join(repo_folder, "blobs")
    if os.path.isdir(blobs_dir) and glob.glob(os.path.join(blobs_dir, "*.incomplete")):
        logger.info(
            "Found .incomplete files in %s for %s. "
            "Considering local snapshot incomplete.",
            blobs_dir,
            model_name_or_path,
        )
        return None

    # Validate that the snapshot contains at least one weight file matching
    # allow_patterns before skipping the download.
    local_weight_files: List[str] = []
    try:
        local_weight_files = _list_existing_matches(
            found_local_snapshot_dir, allow_patterns
        )
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            found_local_snapshot_dir,
            allow_patterns,
            e,
        )
        local_weight_files = []

    # With a list of valid files, check sharded-model completeness.
    missing_shards = _find_missing_shards(found_local_snapshot_dir, local_weight_files)
    if missing_shards:
        logger.info(
            "Found incomplete sharded model %s. Missing shards: %s. "
            "Will attempt download.",
            model_name_or_path,
            missing_shards,
        )
        return None

    if local_weight_files:
        logger.info(
            "Found local HF snapshot for %s at %s; skipping download.",
            model_name_or_path,
            found_local_snapshot_dir,
        )
        return found_local_snapshot_dir

    logger.info(
        "Local HF snapshot at %s has no files matching %s; will attempt download.",
        found_local_snapshot_dir,
        allow_patterns,
    )
    return None
|
381
|
+
|
382
|
+
|
238
383
|
def download_weights_from_hf(
|
239
384
|
model_name_or_path: str,
|
240
385
|
cache_dir: Optional[str],
|
@@ -259,6 +404,16 @@ def download_weights_from_hf(
|
|
259
404
|
Returns:
|
260
405
|
str: The path to the downloaded model weights.
|
261
406
|
"""
|
407
|
+
|
408
|
+
if is_in_ci():
|
409
|
+
# If the weights are already local, skip downloading and return the path.
|
410
|
+
# This is used to skip too-many Huggingface API calls in CI.
|
411
|
+
path = find_local_hf_snapshot_dir(
|
412
|
+
model_name_or_path, cache_dir, allow_patterns, revision
|
413
|
+
)
|
414
|
+
if path is not None:
|
415
|
+
return path
|
416
|
+
|
262
417
|
if not huggingface_hub.constants.HF_HUB_OFFLINE:
|
263
418
|
# Before we download, we look at what is available:
|
264
419
|
fs = HfFileSystem()
|
@@ -680,7 +835,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
|
680
835
|
"""Create a weight loader that shards the weights along the given axis"""
|
681
836
|
|
682
837
|
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
683
|
-
tp_rank =
|
838
|
+
tp_rank = get_attention_tp_rank()
|
684
839
|
|
685
840
|
shard_size = param.data.shape[shard_axis]
|
686
841
|
start_idx = tp_rank * shard_size
|