sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238) hide show
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,7 @@
1
1
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/model_loader/loader.py
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  # ruff: noqa: SIM117
4
6
  import collections
5
7
  import concurrent
@@ -10,14 +12,29 @@ import json
10
12
  import logging
11
13
  import math
12
14
  import os
15
+ import re
16
+ import socket
17
+ import threading
13
18
  import time
14
19
  from abc import ABC, abstractmethod
15
20
  from concurrent.futures import ThreadPoolExecutor
16
21
  from contextlib import contextmanager
17
- from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
22
+ from typing import (
23
+ TYPE_CHECKING,
24
+ Any,
25
+ Dict,
26
+ Generator,
27
+ Iterable,
28
+ List,
29
+ Optional,
30
+ Tuple,
31
+ cast,
32
+ )
33
+ from urllib.parse import urlparse
18
34
 
19
35
  import huggingface_hub
20
36
  import numpy as np
37
+ import requests
21
38
  import safetensors.torch
22
39
  import torch
23
40
  from huggingface_hub import HfApi, hf_hub_download
@@ -26,9 +43,7 @@ from tqdm.auto import tqdm
26
43
  from transformers import AutoModelForCausalLM
27
44
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
28
45
 
29
- from sglang.srt.configs.device_config import DeviceConfig
30
46
  from sglang.srt.configs.load_config import LoadConfig, LoadFormat
31
- from sglang.srt.configs.model_config import ModelConfig
32
47
  from sglang.srt.connector import (
33
48
  ConnectorType,
34
49
  create_remote_connector,
@@ -39,7 +54,6 @@ from sglang.srt.distributed import (
39
54
  get_tensor_model_parallel_rank,
40
55
  get_tensor_model_parallel_world_size,
41
56
  )
42
- from sglang.srt.layers.quantization.base_config import QuantizationConfig
43
57
  from sglang.srt.model_loader.utils import (
44
58
  get_model_architecture,
45
59
  post_load_weights,
@@ -47,6 +61,7 @@ from sglang.srt.model_loader.utils import (
47
61
  )
48
62
  from sglang.srt.model_loader.weight_utils import (
49
63
  _BAR_FORMAT,
64
+ default_weight_loader,
50
65
  download_safetensors_index_file_from_hf,
51
66
  download_weights_from_hf,
52
67
  filter_duplicate_safetensors_files,
@@ -62,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
62
77
  safetensors_weights_iterator,
63
78
  set_runai_streamer_env,
64
79
  )
80
+ from sglang.srt.remote_instance_weight_loader_utils import (
81
+ trigger_transferring_weights_request,
82
+ )
65
83
  from sglang.srt.utils import (
66
84
  get_bool_env_var,
67
85
  get_device_capability,
@@ -70,6 +88,11 @@ from sglang.srt.utils import (
70
88
  set_weight_attrs,
71
89
  )
72
90
 
91
+ if TYPE_CHECKING:
92
+ from sglang.srt.configs.device_config import DeviceConfig
93
+ from sglang.srt.configs.model_config import ModelConfig
94
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
95
+
73
96
  _is_npu = is_npu()
74
97
 
75
98
 
@@ -1366,6 +1389,104 @@ class GGUFModelLoader(BaseModelLoader):
1366
1389
  return model
1367
1390
 
1368
1391
 
1392
+ class RemoteInstanceModelLoader(BaseModelLoader):
1393
+ """Model loader that can load Tensors from remote sglang instance."""
1394
+
1395
+ def __init__(self, load_config: LoadConfig):
1396
+ super().__init__(load_config)
1397
+ if load_config.model_loader_extra_config:
1398
+ raise ValueError(
1399
+ f"Model loader extra config is not supported for "
1400
+ f"load format {load_config.load_format}"
1401
+ )
1402
+
1403
+ def download_model(self, model_config: ModelConfig) -> None:
1404
+ raise NotImplementedError
1405
+
1406
+ def load_model(
1407
+ self,
1408
+ *,
1409
+ model_config: ModelConfig,
1410
+ device_config: DeviceConfig,
1411
+ ) -> nn.Module:
1412
+ logger.info("Loading weights from remote instance ...")
1413
+ load_config = self.load_config
1414
+
1415
+ assert load_config.load_format == LoadFormat.REMOTE_INSTANCE, (
1416
+ f"Model loader {self.load_config.load_format} is not supported for "
1417
+ f"load format {load_config.load_format}"
1418
+ )
1419
+
1420
+ model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
1421
+
1422
+ with set_default_torch_dtype(model_config.dtype):
1423
+ with torch.device(device_config.device):
1424
+ model = _initialize_model(model_config, self.load_config)
1425
+
1426
+ with create_remote_connector(model_weights, device_config.device) as client:
1427
+ connector_type = get_connector_type(client)
1428
+ if connector_type == ConnectorType.INSTANCE:
1429
+ self.load_model_from_remote_instance(
1430
+ model, client, model_config, device_config
1431
+ )
1432
+ else:
1433
+ raise ValueError(
1434
+ f"Unsupported connector type {connector_type} for "
1435
+ f"remote tensor model loading."
1436
+ )
1437
+ return model.eval()
1438
+
1439
+ def load_model_from_remote_instance(
1440
+ self, model, client, model_config: ModelConfig, device_config: DeviceConfig
1441
+ ) -> nn.Module:
1442
+ instance_ip = socket.gethostbyname(socket.gethostname())
1443
+ start_build_group_tic = time.time()
1444
+ client.build_group(
1445
+ gpu_id=device_config.gpu_id,
1446
+ tp_rank=model_config.tp_rank,
1447
+ instance_ip=instance_ip,
1448
+ )
1449
+ torch.cuda.synchronize()
1450
+ end_build_group_tic = time.time()
1451
+ logger.debug(
1452
+ f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
1453
+ )
1454
+
1455
+ if model_config.tp_rank == 0:
1456
+ t = threading.Thread(
1457
+ target=trigger_transferring_weights_request,
1458
+ args=(
1459
+ model_config.remote_instance_weight_loader_seed_instance_ip,
1460
+ model_config.remote_instance_weight_loader_seed_instance_service_port,
1461
+ model_config.remote_instance_weight_loader_send_weights_group_ports,
1462
+ instance_ip,
1463
+ ),
1464
+ )
1465
+ t.start()
1466
+
1467
+ start_get_weights_tic = time.time()
1468
+ with set_default_torch_dtype(model_config.dtype):
1469
+ for _, tensor in model.named_parameters():
1470
+ torch.distributed.broadcast(
1471
+ tensor.data,
1472
+ src=0,
1473
+ group=client._model_update_group,
1474
+ )
1475
+ torch.cuda.synchronize()
1476
+
1477
+ if hasattr(model, "post_load_weights"):
1478
+ model.post_load_weights()
1479
+ end_get_weights_tic = time.time()
1480
+ logger.debug(
1481
+ f"finish getting all weights from remote instance, time used: {(end_get_weights_tic - start_get_weights_tic):.4f}s"
1482
+ )
1483
+ # destroy the process group after loading weights
1484
+ torch.distributed.distributed_c10d.destroy_process_group(
1485
+ client._model_update_group
1486
+ )
1487
+ torch.cuda.empty_cache()
1488
+
1489
+
1369
1490
  class RemoteModelLoader(BaseModelLoader):
1370
1491
  """Model loader that can load Tensors from remote database."""
1371
1492
 
@@ -1567,4 +1688,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
1567
1688
  if load_config.load_format == LoadFormat.REMOTE:
1568
1689
  return RemoteModelLoader(load_config)
1569
1690
 
1691
+ if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
1692
+ return RemoteInstanceModelLoader(load_config)
1693
+
1570
1694
  return DefaultModelLoader(load_config)
@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
35
35
  from sglang.srt.configs.load_config import LoadConfig
36
36
  from sglang.srt.configs.model_config import ModelConfig
37
37
  from sglang.srt.distributed import get_tensor_model_parallel_rank
38
+ from sglang.srt.layers.dp_attention import get_attention_tp_rank
38
39
  from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
39
40
  from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
40
41
  from sglang.srt.utils import print_warning_once
@@ -680,7 +681,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
680
681
  """Create a weight loader that shards the weights along the given axis"""
681
682
 
682
683
  def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
683
- tp_rank = get_tensor_model_parallel_rank()
684
+ tp_rank = get_attention_tp_rank()
684
685
 
685
686
  shard_size = param.data.shape[shard_axis]
686
687
  start_idx = tp_rank * shard_size