sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +18 -3
- sglang/compile_deep_gemm.py +13 -7
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +25 -2
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -5
- sglang/srt/entrypoints/engine.py +13 -5
- sglang/srt/entrypoints/http_server.py +22 -3
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +7 -0
- sglang/srt/eplb/expert_distribution.py +34 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +7 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
- sglang/srt/layers/communicator.py +23 -1
- sglang/srt/layers/layernorm.py +16 -2
- sglang/srt/layers/logits_processor.py +4 -20
- sglang/srt/layers/moe/ep_moe/layer.py +0 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
- sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
- sglang/srt/layers/moe/topk.py +31 -6
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +9 -78
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/rotary_embedding.py +117 -45
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +26 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +164 -129
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +154 -59
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +171 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +11 -11
- sglang/srt/model_executor/model_runner.py +76 -21
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +149 -34
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +0 -1
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +1 -1
- sglang/srt/models/qwen3_moe.py +16 -8
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +103 -22
- sglang/srt/single_batch_overlap.py +4 -1
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +55 -32
- sglang/srt/utils/hf_transformers_utils.py +38 -16
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Usage:
|
|
3
|
+
1) In one terminal, launch the server with the --checkpoint-engine-wait-weights-before-ready option so it waits for the initial weights:
|
|
4
|
+
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
|
|
5
|
+
|
|
6
|
+
2) In another terminal, run this script with torchrun:
|
|
7
|
+
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
|
8
|
+
|
|
9
|
+
Or use the integrated entry point, which launches torchrun automatically:
|
|
10
|
+
python -m sglang.srt.checkpoint_engine.update --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import argparse
|
|
14
|
+
import json
|
|
15
|
+
import os
|
|
16
|
+
import pickle
|
|
17
|
+
import subprocess
|
|
18
|
+
import sys
|
|
19
|
+
import time
|
|
20
|
+
from collections import defaultdict
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from contextlib import contextmanager
|
|
23
|
+
from typing import Literal
|
|
24
|
+
|
|
25
|
+
import httpx
|
|
26
|
+
import torch
|
|
27
|
+
import torch.distributed as dist
|
|
28
|
+
from safetensors import safe_open
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from checkpoint_engine.ps import ParameterServer
|
|
32
|
+
from loguru import logger
|
|
33
|
+
except ImportError:
|
|
34
|
+
# Fallback for when checkpoint_engine is not available
|
|
35
|
+
ParameterServer = None
|
|
36
|
+
import logging
|
|
37
|
+
|
|
38
|
+
logger = logging.getLogger(__name__)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@contextmanager
def timer(msg: str):
    """Log the wall-clock time spent inside the ``with`` body.

    The elapsed time is measured with ``time.perf_counter`` and logged at
    INFO level as ``"<msg> duration: X.XX seconds"`` after the body
    completes normally. If the body raises, the exception propagates and
    nothing is logged.
    """
    t0 = time.perf_counter()
    yield
    elapsed = time.perf_counter() - t0
    logger.info(f"{msg} duration: {elapsed:.2f} seconds")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    """Block until the SGLang server at ``endpoint`` answers ``GET /ping``.

    Only the first rank of each inference group polls: ``RANK`` (default 0)
    must be a multiple of ``inference_parallel_size``. Every other rank
    returns immediately. The poll retries indefinitely every 0.1 s and logs
    a warning on the 1st, 11th, 21st, ... failed attempt.

    Args:
        endpoint: Base URL of the server, e.g. ``http://localhost:19730``.
        inference_parallel_size: Number of ranks per inference instance.
        uds: Optional Unix domain socket path used as the HTTP transport.
    """
    rank = int(os.getenv("RANK", 0))
    if rank % inference_parallel_size != 0:
        return

    transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
    retry_num = 0
    with httpx.Client(transport=transport) as client:
        while True:
            try:
                response = client.get(f"{endpoint}/ping", timeout=10)
                response.raise_for_status()
                return
            except (
                httpx.ConnectError,
                # Connect/read timeouts are not ConnectError subclasses. A
                # server that is still starting up may accept the connection
                # but not answer within 10 s; treat that as "not ready yet"
                # instead of aborting the whole update.
                httpx.TimeoutException,
                httpx.HTTPStatusError,
            ) as e:
                if retry_num % 10 == 0:
                    logger.warning(
                        f"fail to check sglang ready, retry {retry_num} times, error: {e}"
                    )
                retry_num += 1
                time.sleep(0.1)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def split_checkpoint_files(
    checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
    """Return the ``*.safetensors`` files under ``checkpoint_path`` owned by ``rank``.

    Files are split into ``world_size`` contiguous chunks of
    ``ceil(num_files / world_size)`` entries. Trailing ranks may receive
    fewer files or an empty list.

    The file list is sorted before splitting. ``os.listdir`` returns names in
    arbitrary order, so without sorting each rank could compute a different
    partition, loading some files twice and others never.

    Args:
        checkpoint_path: Directory containing the checkpoint shard files.
        rank: Index of the calling process, in ``[0, world_size)``.
        world_size: Total number of processes sharing the files.

    Returns:
        Full paths of this rank's files, in sorted order.
    """
    checkpoint_files = sorted(
        os.path.join(checkpoint_path, name)
        for name in os.listdir(checkpoint_path)
        if name.endswith(".safetensors")
    )
    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def split_tensors(
    checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
    """Load this rank's share of tensors listed in ``model.safetensors.index.json``.

    The ``weight_map`` entries (tensor name -> shard file) are taken in file
    order and split into ``world_size`` contiguous chunks of
    ``ceil(num_entries / world_size)``. Only the chunk for ``rank`` is read.

    Args:
        checkpoint_path: Directory containing the index file and shard files.
        rank: Index of the calling process, in ``[0, world_size)``.
        world_size: Total number of processes sharing the tensors.

    Returns:
        Mapping from tensor name to the tensor loaded via ``safe_open``.
    """
    with open(os.path.join(checkpoint_path, "model.safetensors.index.json")) as index_file:
        weight_map: dict[str, str] = json.load(index_file)["weight_map"]

    all_entries = list(weight_map.items())
    chunk = (len(all_entries) + world_size - 1) // world_size
    my_entries = all_entries[rank * chunk : (rank + 1) * chunk]

    # Group tensor names by shard so each shard file is opened only once.
    names_by_file: dict[str, list[str]] = defaultdict(list)
    for tensor_name, shard_file in my_entries:
        names_by_file[shard_file].append(tensor_name)

    named_tensors: dict[str, torch.Tensor] = {}
    for shard_file, tensor_names in names_by_file.items():
        shard_path = os.path.join(checkpoint_path, shard_file)
        with safe_open(shard_path, framework="pt") as reader:
            named_tensors.update((n, reader.get_tensor(n)) for n in tensor_names)
    return named_tensors
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def req_inference(
|
|
109
|
+
endpoint: str,
|
|
110
|
+
inference_parallel_size: int,
|
|
111
|
+
timeout: float = 300.0,
|
|
112
|
+
uds: str | None = None,
|
|
113
|
+
weight_version: str | None = None,
|
|
114
|
+
) -> Callable[[list[tuple[str, str]]], None]:
|
|
115
|
+
rank = int(os.getenv("RANK", 0))
|
|
116
|
+
src = rank // inference_parallel_size * inference_parallel_size
|
|
117
|
+
|
|
118
|
+
def req_func(socket_paths: list[tuple[str, str]]):
|
|
119
|
+
if rank == src:
|
|
120
|
+
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
|
|
121
|
+
resp = client.post(
|
|
122
|
+
f"{endpoint}/update_weights_from_ipc",
|
|
123
|
+
json={
|
|
124
|
+
"zmq_handles": dict(
|
|
125
|
+
socket_paths[src : src + inference_parallel_size]
|
|
126
|
+
),
|
|
127
|
+
"flush_cache": True,
|
|
128
|
+
"weight_version": weight_version,
|
|
129
|
+
},
|
|
130
|
+
timeout=timeout,
|
|
131
|
+
)
|
|
132
|
+
resp.raise_for_status()
|
|
133
|
+
|
|
134
|
+
return req_func
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def update_weights(
|
|
138
|
+
ps,
|
|
139
|
+
checkpoint_name: str,
|
|
140
|
+
checkpoint_files: list[str],
|
|
141
|
+
named_tensors: dict[str, torch.Tensor],
|
|
142
|
+
req_func: Callable[[list[tuple[str, str]]], None],
|
|
143
|
+
inference_parallel_size: int,
|
|
144
|
+
endpoint: str,
|
|
145
|
+
save_metas_file: str | None = None,
|
|
146
|
+
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
|
|
147
|
+
uds: str | None = None,
|
|
148
|
+
):
|
|
149
|
+
ps.register_checkpoint(
|
|
150
|
+
checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
|
|
151
|
+
)
|
|
152
|
+
ps.init_process_group()
|
|
153
|
+
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
|
154
|
+
dist.barrier()
|
|
155
|
+
with timer("Gather metas"):
|
|
156
|
+
ps.gather_metas(checkpoint_name)
|
|
157
|
+
if save_metas_file and int(os.getenv("RANK")) == 0:
|
|
158
|
+
with open(save_metas_file, "wb") as f:
|
|
159
|
+
pickle.dump(ps.get_metas(), f)
|
|
160
|
+
|
|
161
|
+
if update_method == "broadcast" or update_method == "all":
|
|
162
|
+
with timer("Update weights without setting ranks"):
|
|
163
|
+
ps.update(checkpoint_name, req_func)
|
|
164
|
+
|
|
165
|
+
if update_method == "p2p" or update_method == "all":
|
|
166
|
+
if update_method:
|
|
167
|
+
# sleep 2s to wait destroy process group
|
|
168
|
+
time.sleep(2)
|
|
169
|
+
with timer("Update weights with setting ranks"):
|
|
170
|
+
ps.update(
|
|
171
|
+
checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def join(
    ps: ParameterServer,
    checkpoint_name: str,
    load_metas_file: str,
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    uds: str | None = None,
):
    """Run a p2p weight update using metadata pickled by a previous run.

    The file at ``load_metas_file`` is presumably the one written by
    ``update_weights(save_metas_file=...)``. It is unpickled and passed to
    ``ps.load_metas`` after ``ps.gather_metas``. The update then runs with
    ``ranks=range(inference_parallel_size)``.

    Args:
        ps: Parameter server used for the update.
        checkpoint_name: Name of the checkpoint to update from.
        load_metas_file: Path to the pickled metadata (required).
        req_func: Callback passed to ``ps.update`` (see ``req_inference``).
        inference_parallel_size: Number of ranks per inference instance.
        endpoint: Base URL of the SGLang server.
        uds: Optional Unix domain socket path for the readiness check.
    """
    assert load_metas_file, "load_metas_file is required"
    # NOTE(review): pickle.load can execute arbitrary code from the file.
    # Only load metas files produced by a trusted --save-metas-file run.
    with open(load_metas_file, "rb") as metas_fp:
        saved_metas = pickle.load(metas_fp)

    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()

    with timer("Gather metas before join"):
        ps.gather_metas(checkpoint_name)
    ps.load_metas(saved_metas)

    target_ranks = list(range(inference_parallel_size))
    with timer(
        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
    ):
        ps.update(checkpoint_name, req_func, ranks=target_ranks)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def run_with_torchrun():
    """Re-launch this script under ``torchrun`` and exit with its return code.

    ``--nproc-per-node`` is taken from ``--inference-parallel-size``, in
    either the ``--inference-parallel-size N`` or the
    ``--inference-parallel-size=N`` form. Only the first occurrence is used.
    It defaults to 8 when the flag is absent or its value is not an integer.
    All command-line arguments are forwarded unchanged.

    Exits with 1 if the ``torchrun`` executable is not found, and with 130
    on Ctrl-C.
    """
    forwarded = sys.argv[1:]
    nproc = 8
    flag = "--inference-parallel-size"

    for idx, token in enumerate(forwarded):
        if token == flag and idx + 1 < len(forwarded):
            raw_value = forwarded[idx + 1]
        elif token.startswith(flag + "="):
            raw_value = token.split("=", 1)[1]
        else:
            continue
        try:
            nproc = int(raw_value)
        except ValueError:
            pass  # keep the default; argparse in the child reports the error
        break

    cmd = ["torchrun", f"--nproc-per-node={nproc}", __file__, *forwarded]
    print(f"Running: {' '.join(cmd)}", file=sys.stderr)

    try:
        completed = subprocess.run(cmd, check=False)
        sys.exit(completed.returncode)
    except FileNotFoundError:
        print(
            "Error: torchrun command not found. Please ensure PyTorch is installed.",
            file=sys.stderr,
        )
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nInterrupted by user", file=sys.stderr)
        sys.exit(130)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def main():
    """Update SGLang weights from a checkpoint through checkpoint-engine.

    When ``RANK`` is unset (not launched by torchrun), the script re-launches
    itself via ``run_with_torchrun``. Otherwise each rank does one of two
    things:

    - with ``--load-metas-file``, it joins using the saved metadata;
    - otherwise, it loads its share of ``--checkpoint-path`` (by tensor if
      an index file exists, else by shard file) and runs ``update_weights``.

    Finally it sleeps ``--sleep-time`` seconds.
    """
    # Check if we're running under torchrun or need to invoke it
    if os.getenv("RANK") is None:
        # Not running under torchrun, so re-launch this script through it.
        run_with_torchrun()
        return

    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--checkpoint-path", type=str, default=None)
    parser.add_argument("--save-metas-file", type=str, default=None)
    parser.add_argument("--load-metas-file", type=str, default=None)
    parser.add_argument("--sleep-time", type=int, default=0)
    parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
    parser.add_argument("--inference-parallel-size", type=int, default=8)
    parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
    # update_weights() only acts on these three values; reject anything else
    # at parse time instead of registering the checkpoint and doing nothing.
    parser.add_argument(
        "--update-method",
        type=str,
        default="broadcast",
        choices=["broadcast", "p2p", "all"],
    )
    parser.add_argument("--uds", type=str, default=None)
    parser.add_argument("--weight-version", type=str, default=None)
    args = parser.parse_args()

    # Get rank and world_size from environment (set by torchrun)
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))

    req_func = req_inference(
        args.endpoint,
        args.inference_parallel_size,
        uds=args.uds,
        weight_version=args.weight_version,
    )

    if ParameterServer is None:
        print("Error: checkpoint_engine package not available", file=sys.stderr)
        sys.exit(1)

    ps = ParameterServer(auto_pg=True)
    # NOTE(review): clears a private attribute of ParameterServer; presumably
    # disables its p2p store. Confirm against checkpoint-engine.
    ps._p2p_store = None
    if args.load_metas_file:
        join(
            ps,
            args.checkpoint_name,
            args.load_metas_file,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.uds,
        )
    else:
        if args.checkpoint_path and os.path.exists(
            os.path.join(args.checkpoint_path, "model.safetensors.index.json")
        ):
            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
            checkpoint_files = []
        else:
            checkpoint_files = (
                split_checkpoint_files(args.checkpoint_path, rank, world_size)
                if args.checkpoint_path
                else []
            )
            named_tensors = {}
        update_weights(
            ps,
            args.checkpoint_name,
            checkpoint_files,
            named_tensors,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.save_metas_file,
            args.update_method,
            args.uds,
        )
    time.sleep(args.sleep_time)
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
if __name__ == "__main__":
    # Script entry point. main() re-launches this file under torchrun when
    # the RANK environment variable is not set.
    main()
|
sglang/srt/configs/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ from sglang.srt.configs.dots_vlm import DotsVLMConfig
|
|
|
6
6
|
from sglang.srt.configs.exaone import ExaoneConfig
|
|
7
7
|
from sglang.srt.configs.falcon_h1 import FalconH1Config
|
|
8
8
|
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
|
9
|
+
from sglang.srt.configs.kimi_linear import KimiLinearConfig
|
|
9
10
|
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
|
10
11
|
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
|
11
12
|
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
|
|
@@ -31,6 +32,7 @@ __all__ = [
|
|
|
31
32
|
"Step3TextConfig",
|
|
32
33
|
"Step3VisionEncoderConfig",
|
|
33
34
|
"Olmo3Config",
|
|
35
|
+
"KimiLinearConfig",
|
|
34
36
|
"Qwen3NextConfig",
|
|
35
37
|
"DotsVLMConfig",
|
|
36
38
|
"DotsOCRConfig",
|