sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
# ruff: noqa: SIM117
|
4
6
|
import collections
|
5
7
|
import concurrent
|
@@ -10,14 +12,29 @@ import json
|
|
10
12
|
import logging
|
11
13
|
import math
|
12
14
|
import os
|
15
|
+
import re
|
16
|
+
import socket
|
17
|
+
import threading
|
13
18
|
import time
|
14
19
|
from abc import ABC, abstractmethod
|
15
20
|
from concurrent.futures import ThreadPoolExecutor
|
16
21
|
from contextlib import contextmanager
|
17
|
-
from typing import
|
22
|
+
from typing import (
|
23
|
+
TYPE_CHECKING,
|
24
|
+
Any,
|
25
|
+
Dict,
|
26
|
+
Generator,
|
27
|
+
Iterable,
|
28
|
+
List,
|
29
|
+
Optional,
|
30
|
+
Tuple,
|
31
|
+
cast,
|
32
|
+
)
|
33
|
+
from urllib.parse import urlparse
|
18
34
|
|
19
35
|
import huggingface_hub
|
20
36
|
import numpy as np
|
37
|
+
import requests
|
21
38
|
import safetensors.torch
|
22
39
|
import torch
|
23
40
|
from huggingface_hub import HfApi, hf_hub_download
|
@@ -26,9 +43,7 @@ from tqdm.auto import tqdm
|
|
26
43
|
from transformers import AutoModelForCausalLM
|
27
44
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
28
45
|
|
29
|
-
from sglang.srt.configs.device_config import DeviceConfig
|
30
46
|
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
31
|
-
from sglang.srt.configs.model_config import ModelConfig
|
32
47
|
from sglang.srt.connector import (
|
33
48
|
ConnectorType,
|
34
49
|
create_remote_connector,
|
@@ -39,7 +54,6 @@ from sglang.srt.distributed import (
|
|
39
54
|
get_tensor_model_parallel_rank,
|
40
55
|
get_tensor_model_parallel_world_size,
|
41
56
|
)
|
42
|
-
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
43
57
|
from sglang.srt.model_loader.utils import (
|
44
58
|
get_model_architecture,
|
45
59
|
post_load_weights,
|
@@ -47,6 +61,7 @@ from sglang.srt.model_loader.utils import (
|
|
47
61
|
)
|
48
62
|
from sglang.srt.model_loader.weight_utils import (
|
49
63
|
_BAR_FORMAT,
|
64
|
+
default_weight_loader,
|
50
65
|
download_safetensors_index_file_from_hf,
|
51
66
|
download_weights_from_hf,
|
52
67
|
filter_duplicate_safetensors_files,
|
@@ -62,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
|
|
62
77
|
safetensors_weights_iterator,
|
63
78
|
set_runai_streamer_env,
|
64
79
|
)
|
80
|
+
from sglang.srt.remote_instance_weight_loader_utils import (
|
81
|
+
trigger_transferring_weights_request,
|
82
|
+
)
|
65
83
|
from sglang.srt.utils import (
|
66
84
|
get_bool_env_var,
|
67
85
|
get_device_capability,
|
@@ -70,6 +88,11 @@ from sglang.srt.utils import (
|
|
70
88
|
set_weight_attrs,
|
71
89
|
)
|
72
90
|
|
91
|
+
if TYPE_CHECKING:
|
92
|
+
from sglang.srt.configs.device_config import DeviceConfig
|
93
|
+
from sglang.srt.configs.model_config import ModelConfig
|
94
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
95
|
+
|
73
96
|
_is_npu = is_npu()
|
74
97
|
|
75
98
|
|
@@ -1366,6 +1389,104 @@ class GGUFModelLoader(BaseModelLoader):
|
|
1366
1389
|
return model
|
1367
1390
|
|
1368
1391
|
|
1392
|
+
class RemoteInstanceModelLoader(BaseModelLoader):
|
1393
|
+
"""Model loader that can load Tensors from remote sglang instance."""
|
1394
|
+
|
1395
|
+
def __init__(self, load_config: LoadConfig):
|
1396
|
+
super().__init__(load_config)
|
1397
|
+
if load_config.model_loader_extra_config:
|
1398
|
+
raise ValueError(
|
1399
|
+
f"Model loader extra config is not supported for "
|
1400
|
+
f"load format {load_config.load_format}"
|
1401
|
+
)
|
1402
|
+
|
1403
|
+
def download_model(self, model_config: ModelConfig) -> None:
|
1404
|
+
raise NotImplementedError
|
1405
|
+
|
1406
|
+
def load_model(
|
1407
|
+
self,
|
1408
|
+
*,
|
1409
|
+
model_config: ModelConfig,
|
1410
|
+
device_config: DeviceConfig,
|
1411
|
+
) -> nn.Module:
|
1412
|
+
logger.info("Loading weights from remote instance ...")
|
1413
|
+
load_config = self.load_config
|
1414
|
+
|
1415
|
+
assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
|
1416
|
+
f"Model loader {self.load_config.load_format} is not supported for "
|
1417
|
+
f"load format {load_config.load_format}"
|
1418
|
+
)
|
1419
|
+
|
1420
|
+
model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
|
1421
|
+
|
1422
|
+
with set_default_torch_dtype(model_config.dtype):
|
1423
|
+
with torch.device(device_config.device):
|
1424
|
+
model = _initialize_model(model_config, self.load_config)
|
1425
|
+
|
1426
|
+
with create_remote_connector(model_weights, device_config.device) as client:
|
1427
|
+
connector_type = get_connector_type(client)
|
1428
|
+
if connector_type == ConnectorType.INSTANCE:
|
1429
|
+
self.load_model_from_remote_instance(
|
1430
|
+
model, client, model_config, device_config
|
1431
|
+
)
|
1432
|
+
else:
|
1433
|
+
raise ValueError(
|
1434
|
+
f"Unsupported connector type {connector_type} for "
|
1435
|
+
f"remote tensor model loading."
|
1436
|
+
)
|
1437
|
+
return model.eval()
|
1438
|
+
|
1439
|
+
def load_model_from_remote_instance(
|
1440
|
+
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
|
1441
|
+
) -> nn.Module:
|
1442
|
+
instance_ip = socket.gethostbyname(socket.gethostname())
|
1443
|
+
start_build_group_tic = time.time()
|
1444
|
+
client.build_group(
|
1445
|
+
gpu_id=device_config.gpu_id,
|
1446
|
+
tp_rank=model_config.tp_rank,
|
1447
|
+
instance_ip=instance_ip,
|
1448
|
+
)
|
1449
|
+
torch.cuda.synchronize()
|
1450
|
+
end_build_group_tic = time.time()
|
1451
|
+
logger.debug(
|
1452
|
+
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
|
1453
|
+
)
|
1454
|
+
|
1455
|
+
if model_config.tp_rank == 0:
|
1456
|
+
t = threading.Thread(
|
1457
|
+
target=trigger_transferring_weights_request,
|
1458
|
+
args=(
|
1459
|
+
model_config.remote_instance_weight_loader_seed_instance_ip,
|
1460
|
+
model_config.remote_instance_weight_loader_seed_instance_service_port,
|
1461
|
+
model_config.remote_instance_weight_loader_send_weights_group_ports,
|
1462
|
+
instance_ip,
|
1463
|
+
),
|
1464
|
+
)
|
1465
|
+
t.start()
|
1466
|
+
|
1467
|
+
start_get_weights_tic = time.time()
|
1468
|
+
with set_default_torch_dtype(model_config.dtype):
|
1469
|
+
for _, tensor in model.named_parameters():
|
1470
|
+
torch.distributed.broadcast(
|
1471
|
+
tensor.data,
|
1472
|
+
src=0,
|
1473
|
+
group=client._model_update_group,
|
1474
|
+
)
|
1475
|
+
torch.cuda.synchronize()
|
1476
|
+
|
1477
|
+
if hasattr(model, "post_load_weights"):
|
1478
|
+
model.post_load_weights()
|
1479
|
+
end_get_weights_tic = time.time()
|
1480
|
+
logger.debug(
|
1481
|
+
f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
|
1482
|
+
)
|
1483
|
+
# destroy the process group after loading weights
|
1484
|
+
torch.distributed.distributed_c10d.destroy_process_group(
|
1485
|
+
client._model_update_group
|
1486
|
+
)
|
1487
|
+
torch.cuda.empty_cache()
|
1488
|
+
|
1489
|
+
|
1369
1490
|
class RemoteModelLoader(BaseModelLoader):
|
1370
1491
|
"""Model loader that can load Tensors from remote database."""
|
1371
1492
|
|
@@ -1567,4 +1688,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
|
1567
1688
|
if load_config.load_format == LoadFormat.REMOTE:
|
1568
1689
|
return RemoteModelLoader(load_config)
|
1569
1690
|
|
1691
|
+
if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
|
1692
|
+
return RemoteInstanceModelLoader(load_config)
|
1693
|
+
|
1570
1694
|
return DefaultModelLoader(load_config)
|
@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
|
|
35
35
|
from sglang.srt.configs.load_config import LoadConfig
|
36
36
|
from sglang.srt.configs.model_config import ModelConfig
|
37
37
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
38
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_rank
|
38
39
|
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
39
40
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
40
41
|
from sglang.srt.utils import print_warning_once
|
@@ -680,7 +681,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
|
680
681
|
"""Create a weight loader that shards the weights along the given axis"""
|
681
682
|
|
682
683
|
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
683
|
-
tp_rank = get_tensor_model_parallel_rank()
|
684
|
+
tp_rank = get_attention_tp_rank()
|
684
685
|
|
685
686
|
shard_size = param.data.shape[shard_axis]
|
686
687
|
start_idx = tp_rank * shard_size
|