sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +18 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +76 -20
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +267 -170
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +245 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -14,19 +14,13 @@
 """ModelRunner runs the forward passes of the models."""
 
 import gc
-import importlib
-import importlib.resources
-import inspect
 import json
 import logging
-import pkgutil
-from functools import lru_cache
-from typing import Optional, Type
+import time
+from typing import Optional
 
 import torch
-import torch.nn as nn
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
+import torch.distributed as dist
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
@@ -34,12 +28,13 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry
 
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
@@ -52,14 +47,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    init_custom_process_group,
     is_hip,
-    monkey_patch_vllm_model_config,
+    monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -118,7 +114,7 @@ class ModelRunner:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
-            server_args.chunked_prefill_size = None
+            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
@@ -129,7 +125,7 @@ class ModelRunner:
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
-        if server_args.disable_disk_cache:
+        if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache
 
             disable_cache()
@@ -148,12 +144,14 @@ class ModelRunner:
 
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
-        #
+        # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
+
+        # Load the model
         self.sampler = Sampler()
         self.load_model()
 
-        # Apply torch TP if model supports it
+        # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
@@ -161,6 +159,7 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
+        # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -209,16 +208,6 @@ class ModelRunner:
         )
         self.tp_group = get_tp_group()
 
-        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
-        # so we disable padding in cuda graph.
-        if self.device == "cuda" and not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        ):
-            self.server_args.disable_cuda_graph_padding = True
-            logger.info(
-                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
-            )
-
         # Check memory for tensor parallelism
         if self.tp_size > 1:
             local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -229,49 +218,6 @@ class ModelRunner:
 
         return min_per_gpu_memory
 
-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-
-        if "task" in sig.parameters:
-            params["task"] = ""
-
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -285,6 +231,7 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            self.model_config.dtype = torch.float16
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
@@ -293,21 +240,21 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )
 
-        self.model = self.setup_model()
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )
 
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.vllm_model_config.dtype
+        self.dtype = self.model_config.dtype
 
         logger.info(
             f"Load weight end. "
@@ -316,30 +263,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-    def update_weights(self, model_path: str, load_format: str):
-        """Update weights
-        from vllm.model_executor.model_loader.loader import (
+    def update_weights_from_disk(self, model_path: str, load_format: str):
+        """Update engine weights online from disk."""
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+        from sglang.srt.model_loader.utils import set_default_torch_dtype
 
         logger.info(
-            f"Update weights begin. "
+            f"Update engine weights online from disk begin. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
         target_device = torch.device(self.device)
-
-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
+        self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
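Note: the `update_weights` entry point is renamed to `update_weights_from_disk` and now mutates SGLang's own `ModelConfig`/`LoadConfig` instead of rebuilding a vLLM `ModelConfig`. A minimal caller-side sketch of the renamed API; the runner instance, checkpoint path, and load format are illustrative placeholders, not values from this diff:

```python
def refresh_engine_weights(model_runner, checkpoint_dir: str) -> None:
    """Hedged usage sketch: reload weights from disk and surface failures.

    `model_runner` is an already-initialized ModelRunner; `checkpoint_dir`
    and the "auto" load format are placeholders.
    """
    success, message = model_runner.update_weights_from_disk(
        model_path=checkpoint_dir,
        load_format="auto",
    )
    if not success:
        # The method reports failure via its (bool, str) return value
        # rather than raising, so the caller raises here.
        raise RuntimeError(message)
```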
@@ -351,7 +290,7 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.model,
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True
@@ -369,9 +308,9 @@ class ModelRunner:
                         quant_method.process_weights_after_loading(module)
             return model
 
-        with set_default_torch_dtype(vllm_model_config.dtype):
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(vllm_model_config)
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message
@@ -383,20 +322,115 @@ class ModelRunner:
                )
                del iter
                gc.collect()
-                iter = get_weight_iter(self.vllm_model_config)
+                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message
 
         self.model = model
         self.server_args.model_path = model_path
         self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
         self.load_config = load_config
-        self.model_config.path = model_path
 
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
 
+    def init_weights_update_group(
+        self,
+        master_address,
+        master_port,
+        rank_offset,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        """Initialize the Torch process group for model parameter updates.
+
+        `_model_update_group` is used in the RLHF workflow, where rank
+        0 is the actor model in the training engine, and the other ranks are
+        the inference engine, which is used for rollout.
+
+        In the RLHF workflow, the training engine updates the model
+        weights/parameters online, and broadcasts them to the inference
+        engine through the `_model_update_group` process group.
+        """
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        rank = rank_offset + self.tp_rank
+
+        logger.info(
+            f"init custom process group: master_address={master_address}, master_port={master_port}, "
+            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        try:
+            self._model_update_group = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{master_port}",
+                world_size=world_size,
+                rank=rank,
+                group_name=group_name,
+            )
+            dist.barrier(group=self._model_update_group, device_ids=[rank])
+            return True, "Succeeded to initialize custom process group."
+        except Exception as e:
+            message = f"Failed to initialize custom process group: {e}."
+            logger.error(message)
+            return False, message
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """
+        Update specific parameter in the model weights online
+        through `_model_update_group` process group.
+
+        Args:
+            name: the name of the parameter to be updated.
+            dtype: the data type of the parameter to be updated.
+            shape: the shape of the parameter to be updated.
+        """
+        target_dtype = (
+            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+        )
+        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
+
+        assert (
+            self._model_update_group is not None
+        ), "model update group must be initialized"
+
+        try:
+            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
+            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
+            self.model.load_weights([(name, weights)])
+            return True, f"Succeeded to update parameter {name} online."
+
+        except Exception as e:
+            error_msg = (
+                f"Failed to update parameter online: {e}. "
+                f"The full weights of the ModelRunner are partially updated. "
+                f"Please discard the whole weights."
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        # TODO: (chenyang) Add support for Qwen models.
+        try:
+            return self.model.get_weights_by_name(
+                name, truncate_size, tp_size=self.tp_size
+            )
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
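Note: these three new methods are the inference-side half of an online weight-sync path for RLHF: the trainer (rank 0) and the SGLang TP ranks join one custom process group, then each parameter is broadcast from rank 0 and applied through `model.load_weights`. A rough sketch of what the trainer side could look like; the process-group bootstrap and the RPC that triggers the engine-side calls are assumptions and are not part of this diff:

```python
import torch
import torch.distributed as dist


def broadcast_weights_to_engine(model: torch.nn.Module, group) -> None:
    """Hedged sketch of the trainer (rank 0) side of the broadcast protocol.

    Assumes `group` was created with the same world size and rank layout the
    engine passed to init_weights_update_group, and that for every tensor
    sent here the engine concurrently calls
    update_weights_from_distributed(name, dtype, shape) so it receives into a
    matching empty tensor and applies it via model.load_weights.
    """
    for name, param in model.named_parameters():
        # src=0 mirrors the engine-side broadcast in update_weights_from_distributed.
        dist.broadcast(param.data, src=0, group=group)
```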
@@ -547,6 +581,8 @@ class ModelRunner:
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
                 self.attn_backend = TritonAttnBackend(self)
+        elif self.server_args.attention_backend == "torch_native":
+            self.attn_backend = TorchNativeAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
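Note: `torch_native` joins `flashinfer` and `triton` as a selectable attention backend and dispatches to the new `TorchNativeAttnBackend`. One way to request it from Python, assuming only that `ServerArgs` keeps its dataclass form; the model path is a placeholder:

```python
# Hedged sketch: selecting the new backend via server arguments.
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    attention_backend="torch_native",  # routes to TorchNativeAttnBackend above
)
```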
@@ -583,8 +619,10 @@ class ModelRunner:
         if self.server_args.disable_cuda_graph:
             return
 
+        tic = time.time()
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -694,55 +732,3 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(
-                    entry, list
-                ): # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
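Note: the module scan and the vLLM `ModelRegistry` monkey patches are removed from this file; judging from the file list above, the new `sglang/srt/models/registry.py` (+99 lines) presumably takes over model registration. The per-module `EntryClass` convention the removed scanner consumed looks roughly like the hypothetical module below:

```python
# Hedged sketch of the EntryClass convention referenced by the removed scanner.
# MyModelForCausalLM is a hypothetical model class, not one from this diff.
from torch import nn


class MyModelForCausalLM(nn.Module):
    """Placeholder model implementation."""


# A single class, or a list when one module implements several architectures.
EntryClass = MyModelForCausalLM
```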
sglang/srt/model_loader/__init__.py

@@ -0,0 +1,34 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
+
+from torch import nn
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
+from sglang.srt.model_loader.utils import (
+    get_architecture_class_name,
+    get_model_architecture,
+)
+
+
+def get_model(
+    *,
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    loader = get_model_loader(load_config)
+    return loader.load_model(
+        model_config=model_config,
+        device_config=device_config,
+    )
+
+
+__all__ = [
+    "get_model",
+    "get_model_loader",
+    "BaseModelLoader",
+    "get_architecture_class_name",
+    "get_model_architecture",
+]
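Note: this new `get_model` mirrors vLLM's loader entry point but takes SGLang's own config objects, which is how `ModelRunner.load_model` above calls it. A standalone sketch follows; only the keyword-only `get_model` signature and the `LoadConfig`/`DeviceConfig` calls come from this diff, while the `ModelConfig` constructor arguments are an assumption:

```python
# Hedged sketch of calling the new loader entry point outside ModelRunner.
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig  # constructor args assumed
from sglang.srt.model_loader import get_model

# Assumption: ModelConfig can be built from a model path, as ModelRunner does
# internally; the exact signature lives in sglang/srt/configs/model_config.py.
model_config = ModelConfig("meta-llama/Llama-3.1-8B-Instruct")  # placeholder path

model = get_model(
    model_config=model_config,
    load_config=LoadConfig(load_format="auto"),
    device_config=DeviceConfig("cuda"),
)
```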