sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/loader.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
# ruff: noqa: SIM117
|
4
6
|
import collections
|
5
7
|
import concurrent
|
@@ -14,7 +16,17 @@ import time
|
|
14
16
|
from abc import ABC, abstractmethod
|
15
17
|
from concurrent.futures import ThreadPoolExecutor
|
16
18
|
from contextlib import contextmanager
|
17
|
-
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
|
19
|
+
from typing import (
|
20
|
+
TYPE_CHECKING,
|
21
|
+
Any,
|
22
|
+
Dict,
|
23
|
+
Generator,
|
24
|
+
Iterable,
|
25
|
+
List,
|
26
|
+
Optional,
|
27
|
+
Tuple,
|
28
|
+
cast,
|
29
|
+
)
|
18
30
|
|
19
31
|
import huggingface_hub
|
20
32
|
import numpy as np
|
@@ -26,9 +38,7 @@ from tqdm.auto import tqdm
|
|
26
38
|
from transformers import AutoModelForCausalLM
|
27
39
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
28
40
|
|
29
|
-
from sglang.srt.configs.device_config import DeviceConfig
|
30
41
|
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
31
|
-
from sglang.srt.configs.model_config import ModelConfig
|
32
42
|
from sglang.srt.connector import (
|
33
43
|
ConnectorType,
|
34
44
|
create_remote_connector,
|
@@ -39,9 +49,9 @@ from sglang.srt.distributed import (
|
|
39
49
|
get_tensor_model_parallel_rank,
|
40
50
|
get_tensor_model_parallel_world_size,
|
41
51
|
)
|
42
|
-
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
43
52
|
from sglang.srt.model_loader.utils import (
|
44
53
|
get_model_architecture,
|
54
|
+
post_load_weights,
|
45
55
|
set_default_torch_dtype,
|
46
56
|
)
|
47
57
|
from sglang.srt.model_loader.weight_utils import (
|
@@ -69,6 +79,11 @@ from sglang.srt.utils import (
|
|
69
79
|
set_weight_attrs,
|
70
80
|
)
|
71
81
|
|
82
|
+
if TYPE_CHECKING:
|
83
|
+
from sglang.srt.configs.device_config import DeviceConfig
|
84
|
+
from sglang.srt.configs.model_config import ModelConfig
|
85
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
86
|
+
|
72
87
|
_is_npu = is_npu()
|
73
88
|
|
74
89
|
|
@@ -600,18 +615,7 @@ class DummyModelLoader(BaseModelLoader):
|
|
600
615
|
# random values to the weights.
|
601
616
|
initialize_dummy_weights(model)
|
602
617
|
|
603
|
-
|
604
|
-
# 1. Initial weight loading.
|
605
|
-
# 2. Post-processing of weights, including assigning specific member variables.
|
606
|
-
# For `dummy_init`, only the second stage is required.
|
607
|
-
if hasattr(model, "post_load_weights"):
|
608
|
-
if (
|
609
|
-
model_config.hf_config.architectures[0]
|
610
|
-
== "DeepseekV3ForCausalLMNextN"
|
611
|
-
):
|
612
|
-
model.post_load_weights(is_nextn=True)
|
613
|
-
else:
|
614
|
-
model.post_load_weights()
|
618
|
+
post_load_weights(model, model_config)
|
615
619
|
|
616
620
|
return model.eval()
|
617
621
|
|
@@ -751,6 +755,9 @@ class ShardedStateLoader(BaseModelLoader):
|
|
751
755
|
state_dict.pop(key)
|
752
756
|
if state_dict:
|
753
757
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
758
|
+
|
759
|
+
post_load_weights(model, model_config)
|
760
|
+
|
754
761
|
return model.eval()
|
755
762
|
|
756
763
|
@staticmethod
|
@@ -1421,18 +1428,16 @@ class RemoteModelLoader(BaseModelLoader):
|
|
1421
1428
|
# ignore hidden files
|
1422
1429
|
if file_name.startswith("."):
|
1423
1430
|
continue
|
1424
|
-
if os.path.splitext(file_name)[1] not in (
|
1425
|
-
".bin",
|
1426
|
-
".pt",
|
1427
|
-
".safetensors",
|
1428
|
-
):
|
1431
|
+
if os.path.splitext(file_name)[1] in (".json", ".py"):
|
1429
1432
|
file_path = os.path.join(root, file_name)
|
1430
1433
|
with open(file_path, encoding="utf-8") as file:
|
1431
1434
|
file_content = file.read()
|
1432
1435
|
f_key = f"{model_name}/files/{file_name}"
|
1433
1436
|
client.setstr(f_key, file_content)
|
1434
1437
|
|
1435
|
-
def _load_model_from_remote_kv(self, model: nn.Module, client):
|
1438
|
+
def _load_model_from_remote_kv(
|
1439
|
+
self, model: nn.Module, model_config: ModelConfig, client
|
1440
|
+
):
|
1436
1441
|
for _, module in model.named_modules():
|
1437
1442
|
quant_method = getattr(module, "quant_method", None)
|
1438
1443
|
if quant_method is not None:
|
@@ -1460,6 +1465,8 @@ class RemoteModelLoader(BaseModelLoader):
|
|
1460
1465
|
if state_dict:
|
1461
1466
|
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
|
1462
1467
|
|
1468
|
+
post_load_weights(model, model_config)
|
1469
|
+
|
1463
1470
|
def _load_model_from_remote_fs(
|
1464
1471
|
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
1465
1472
|
) -> nn.Module:
|
@@ -1501,15 +1508,13 @@ class RemoteModelLoader(BaseModelLoader):
|
|
1501
1508
|
with set_default_torch_dtype(model_config.dtype):
|
1502
1509
|
with torch.device(device_config.device):
|
1503
1510
|
model = _initialize_model(model_config, self.load_config)
|
1504
|
-
for _, module in model.named_modules():
|
1505
|
-
quant_method = getattr(module, "quant_method", None)
|
1506
|
-
if quant_method is not None:
|
1507
|
-
quant_method.process_weights_after_loading(module)
|
1508
1511
|
|
1509
|
-
with create_remote_connector(model_weights, device_config.device) as client:
|
1512
|
+
with create_remote_connector(
|
1513
|
+
model_weights, device=device_config.device
|
1514
|
+
) as client:
|
1510
1515
|
connector_type = get_connector_type(client)
|
1511
1516
|
if connector_type == ConnectorType.KV:
|
1512
|
-
self._load_model_from_remote_kv(model, client)
|
1517
|
+
self._load_model_from_remote_kv(model, model_config, client)
|
1513
1518
|
elif connector_type == ConnectorType.FS:
|
1514
1519
|
self._load_model_from_remote_fs(
|
1515
1520
|
model, client, model_config, device_config
|
sglang/srt/model_loader/utils.py
CHANGED
@@ -105,3 +105,15 @@ def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module],
|
|
105
105
|
|
106
106
|
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
107
107
|
return get_model_architecture(model_config)[1]
|
108
|
+
|
109
|
+
|
110
|
+
def post_load_weights(model: nn.Module, model_config: ModelConfig):
|
111
|
+
# Model weight loading consists of two stages:
|
112
|
+
# 1. Initial weight loading.
|
113
|
+
# 2. Post-processing of weights, including assigning specific member variables.
|
114
|
+
# For `dummy_init`, only the second stage is required.
|
115
|
+
if hasattr(model, "post_load_weights"):
|
116
|
+
if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN":
|
117
|
+
model.post_load_weights(is_nextn=True)
|
118
|
+
else:
|
119
|
+
model.post_load_weights()
|
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
|
|
35
35
|
from sglang.srt.configs.load_config import LoadConfig
|
36
36
|
from sglang.srt.configs.model_config import ModelConfig
|
37
37
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
38
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
38
39
|
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
39
40
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
40
41
|
from sglang.srt.utils import print_warning_once
|
@@ -680,7 +681,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
|
680
681
|
"""Create a weight loader that shards the weights along the given axis"""
|
681
682
|
|
682
683
|
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
683
|
-
tp_rank = get_tensor_model_parallel_rank()
|
684
|
+
tp_rank = get_attention_tp_rank()
|
684
685
|
|
685
686
|
shard_size = param.data.shape[shard_axis]
|
686
687
|
start_idx = tp_rank * shard_size
|