sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py

@@ -14,35 +14,30 @@
 """ModelRunner runs the forward passes of the models."""
 
 import gc
-import importlib
-import importlib.resources
-import inspect
 import json
 import logging
-import
-from
-from typing import Optional, Type
+import time
+from typing import Optional
 
 import torch
-import torch.
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
+import torch.distributed as dist
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry
 
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -52,14 +47,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    init_custom_process_group,
     is_hip,
-
+    monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -115,11 +111,13 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
+            server_args.chunked_prefill_size = -1
+            self.mem_fraction_static *= 0.95
             logger.info(
-                "Automatically
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
+                f"and turn off chunked prefill "
+                f"because this is a multimodal model."
             )
-            server_args.chunked_prefill_size = None
-            self.mem_fraction_static *= 0.95
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -129,7 +127,7 @@ class ModelRunner:
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
-        if server_args.
+        if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache
 
             disable_cache()
@@ -143,17 +141,20 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
             }
         )
 
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
-        #
+        # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
+
+        # Load the model
         self.sampler = Sampler()
         self.load_model()
 
-        # Apply torch TP if model supports it
+        # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
@@ -161,6 +162,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
+        # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -209,16 +215,6 @@ class ModelRunner:
         )
         self.tp_group = get_tp_group()
 
-        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
-        # so we disable padding in cuda graph.
-        if self.device == "cuda" and not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        ):
-            self.server_args.disable_cuda_graph_padding = True
-            logger.info(
-                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
-            )
-
         # Check memory for tensor parallelism
         if self.tp_size > 1:
             local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -229,49 +225,6 @@ class ModelRunner:
 
         return min_per_gpu_memory
 
-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-
-        if "task" in sig.parameters:
-            params["task"] = ""
-
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -285,6 +238,7 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            self.model_config.dtype = torch.float16
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
@@ -293,21 +247,21 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )
 
-        self.
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )
 
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.
+        self.dtype = self.model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -316,30 +270,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-    def
-        """Update weights
-        from
+    def update_weights_from_disk(self, model_path: str, load_format: str):
+        """Update engine weights online from disk."""
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from
+        from sglang.srt.model_loader.utils import set_default_torch_dtype
 
         logger.info(
-            f"Update weights begin. "
+            f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
         target_device = torch.device(self.device)
-
-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
+        self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
@@ -351,7 +297,7 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True
@@ -369,9 +315,9 @@ class ModelRunner:
                         quant_method.process_weights_after_loading(module)
             return model
 
-        with set_default_torch_dtype(
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message
@@ -383,20 +329,115 @@ class ModelRunner:
                 )
                 del iter
                 gc.collect()
-                iter = get_weight_iter(self.
+                iter = get_weight_iter(self.model_config)
                 self.model = model_load_weights(self.model, iter)
                 return False, message
 
         self.model = model
         self.server_args.model_path = model_path
         self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
         self.load_config = load_config
-        self.model_config.path = model_path
 
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
 
+    def init_weights_update_group(
+        self,
+        master_address,
+        master_port,
+        rank_offset,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        """Initialize the Torch process group for model parameter updates.
+
+        `_model_update_group` is used in the RLHF workflow, where rank
+        0 is the actor model in the training engine, and the other ranks are
+        the inference engine, which is used for rollout.
+
+        In the RLHF workflow, the training engine updates the model
+        weights/parameters online, and broadcasts them to the inference
+        engine through the `_model_update_group` process group.
+        """
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        rank = rank_offset + self.tp_rank
+
+        logger.info(
+            f"init custom process group: master_address={master_address}, master_port={master_port}, "
+            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        try:
+            self._model_update_group = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{master_port}",
+                world_size=world_size,
+                rank=rank,
+                group_name=group_name,
+            )
+            dist.barrier(group=self._model_update_group, device_ids=[rank])
+            return True, "Succeeded to initialize custom process group."
+        except Exception as e:
+            message = f"Failed to initialize custom process group: {e}."
+            logger.error(message)
+            return False, message
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """
+        Update specific parameter in the model weights online
+        through `_model_update_group` process group.
+
+        Args:
+            name: the name of the parameter to be updated.
+            dtype: the data type of the parameter to be updated.
+            shape: the shape of the parameter to be updated.
+        """
+        target_dtype = (
+            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+        )
+        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
+
+        assert (
+            self._model_update_group is not None
+        ), "model update group must be initialized"
+
+        try:
+            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
+            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
+            self.model.load_weights([(name, weights)])
+            return True, f"Succeeded to update parameter {name} online."
+
+        except Exception as e:
+            error_msg = (
+                f"Failed to update parameter online: {e}. "
+                f"The full weights of the ModelRunner are partially updated. "
+                f"Please discard the whole weights."
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        # TODO: (chenyang) Add support for Qwen models.
+        try:
+            return self.model.get_weights_by_name(
+                name, truncate_size, tp_size=self.tp_size
+            )
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
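
The hunk above adds an online weight-sync path for RLHF: the inference engine joins a dedicated process group, receives parameter broadcasts from the trainer (rank 0), and loads them in place. The sketch below shows how a caller might drive these new ModelRunner methods; the address, port, group name, parameter name, and shape are hypothetical, and the trainer is assumed to join the same group and broadcast a tensor with matching dtype and shape for each update.

# Hypothetical driver for the new online weight-update path (inference side).
# Assumes `model_runner` is an initialized ModelRunner; all literal values below
# are placeholders, not taken from this diff.
import torch

ok, msg = model_runner.init_weights_update_group(
    master_address="127.0.0.1",      # hypothetical rendezvous address
    master_port=29600,               # hypothetical free port
    rank_offset=1,                   # rank 0 is reserved for the training engine
    world_size=1 + model_runner.tp_size,
    group_name="rlhf_weight_sync",   # hypothetical group name
    backend="nccl",
)
assert ok, msg

# Receive one parameter broadcast by the trainer and load it into the model.
ok, msg = model_runner.update_weights_from_distributed(
    name="model.embed_tokens.weight",   # hypothetical parameter name
    dtype=torch.bfloat16,
    shape=(128256, 4096),               # hypothetical shape
)
assert ok, msg

# Inspect a truncated copy of the parameter, e.g. in a unit test.
weights = model_runner.get_weights_by_name("model.embed_tokens.weight", truncate_size=8)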
@@ -547,6 +588,8 @@ class ModelRunner:
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
                 self.attn_backend = TritonAttnBackend(self)
+        elif self.server_args.attention_backend == "torch_native":
+            self.attn_backend = TorchNativeAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
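
This hunk wires the new pure-PyTorch attention backend into the backend dispatch, keyed off `self.server_args.attention_backend`. Assuming the server argument accepts the new value and that `sglang.Engine` forwards keyword arguments to ServerArgs (both assumptions, not shown in this diff), selecting it might look like the following sketch; the model path is a placeholder.

# Hypothetical sketch: selecting the torch-native attention backend offline.
import sglang as sgl

llm = sgl.Engine(
    model_path="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    attention_backend="torch_native",       # instead of "flashinfer" or "triton"
)
# Generation then proceeds through the usual Engine APIs.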
@@ -583,8 +626,10 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return
 
+        tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -694,55 +739,3 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(
-                    entry, list
-                ):  # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)

sglang/srt/model_loader/__init__.py (new file)

@@ -0,0 +1,34 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
+
+from torch import nn
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
+from sglang.srt.model_loader.utils import (
+    get_architecture_class_name,
+    get_model_architecture,
+)
+
+
+def get_model(
+    *,
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    loader = get_model_loader(load_config)
+    return loader.load_model(
+        model_config=model_config,
+        device_config=device_config,
+    )
+
+
+__all__ = [
+    "get_model",
+    "get_model_loader",
+    "BaseModelLoader",
+    "get_architecture_class_name",
+    "get_model_architecture",
+]
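
This new in-tree `get_model` entry point replaces the previous dependency on `vllm.model_executor.model_loader.get_model`, and `ModelRunner.load_model` now composes it from sglang's own config objects (see the hunk at @@ -293,21 +247,21 @@ above). A rough standalone sketch of that loading path is given below; the ModelConfig constructor arguments are assumptions based on what this diff shows (only `model_path`, `dtype`, and `load_format` usage is visible), not a verified signature.

# Hypothetical sketch of the new in-tree model loading path. DeviceConfig takes
# the device string and LoadConfig takes load_format, as seen in the diff;
# passing the model path positionally to ModelConfig is an assumption.
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model

model_config = ModelConfig("Qwen/Qwen2.5-7B-Instruct")  # placeholder model path
load_config = LoadConfig(load_format="auto")
device_config = DeviceConfig("cuda")

model = get_model(
    model_config=model_config,
    load_config=load_config,
    device_config=device_config,
)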