sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py:
@@ -14,19 +14,13 @@
 """ModelRunner runs the forward passes of the models."""

 import gc
-import importlib
-import importlib.resources
-import inspect
 import json
 import logging
-import pkgutil
-from functools import lru_cache
-from typing import Optional, Type
+import time
+from typing import Optional

 import torch
-import torch.nn as nn
-from vllm.config import DeviceConfig, LoadConfig
-from vllm.config import ModelConfig as VllmModelConfig
+import torch.distributed as dist
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
@@ -34,12 +28,13 @@ from vllm.distributed import (
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
-from vllm.model_executor.model_loader import get_model
-from vllm.model_executor.models import ModelRegistry

+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
@@ -52,14 +47,15 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader import get_model
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    init_custom_process_group,
     is_hip,
-    monkey_patch_vllm_model_config,
+    monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
@@ -118,7 +114,7 @@ class ModelRunner:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
             )
-            server_args.chunked_prefill_size = None
+            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
@@ -129,7 +125,7 @@ class ModelRunner:
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
-        if server_args.disable_disk_cache:
+        if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache

             disable_cache()
@@ -148,12 +144,14 @@ class ModelRunner:

         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

-        # Init components
+        # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
+
+        # Load the model
         self.sampler = Sampler()
         self.load_model()

-        # Apply torch TP if model supports it
+        # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
@@ -161,6 +159,7 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False

+        # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -209,16 +208,6 @@ class ModelRunner:
         )
         self.tp_group = get_tp_group()

-        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
-        # so we disable padding in cuda graph.
-        if self.device == "cuda" and not all(
-            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-        ):
-            self.server_args.disable_cuda_graph_padding = True
-            logger.info(
-                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
-            )
-
         # Check memory for tensor parallelism
         if self.tp_size > 1:
             local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -229,49 +218,6 @@ class ModelRunner:

         return min_per_gpu_memory

-    def setup_model(self):
-        try:
-            from vllm.config import VllmConfig
-
-            vllm_config = VllmConfig()
-            vllm_config.model_config = self.vllm_model_config
-            vllm_config.load_config = self.load_config
-            vllm_config.device_config = DeviceConfig(self.device)
-            vllm_config.quant_config = VllmConfig._get_quantization_config(
-                vllm_config.model_config, vllm_config.load_config
-            )
-            return get_model(vllm_config=vllm_config)
-        except ImportError:
-            pass
-
-        return get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
-
-    def get_model_config_params(self):
-        sig = inspect.signature(VllmModelConfig.__init__)
-        params = {
-            "model": self.server_args.model_path,
-            "quantization": self.server_args.quantization,
-            "tokenizer": None,
-            "tokenizer_mode": None,
-            "trust_remote_code": self.server_args.trust_remote_code,
-            "dtype": self.server_args.dtype,
-            "seed": self.server_args.random_seed,
-            "skip_tokenizer_init": True,
-        }
-
-        if "task" in sig.parameters:
-            params["task"] = ""
-
-        return params
-
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -285,6 +231,7 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            self.model_config.dtype = torch.float16
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")

@@ -293,21 +240,21 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        monkey_patch_vllm_model_config()
-        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-        if self.model_config.model_override_args is not None:
-            self.vllm_model_config.hf_config.update(
-                self.model_config.model_override_args
-            )

-        self.model = self.setup_model()
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=DeviceConfig(self.device),
+        )

         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.dtype = self.vllm_model_config.dtype
+        self.dtype = self.model_config.dtype

         logger.info(
             f"Load weight end. "
@@ -316,30 +263,22 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-    def update_weights(self, model_path: str, load_format: str):
-        """Update weights in-place."""
-        from vllm.model_executor.model_loader.loader import (
+    def update_weights_from_disk(self, model_path: str, load_format: str):
+        """Update engine weights online from disk."""
+        from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
             get_model_loader,
         )
-        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+        from sglang.srt.model_loader.utils import set_default_torch_dtype

         logger.info(
-            f"Update weights begin. "
+            f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         target_device = torch.device(self.device)
-
-        try:
-            model_config_params = self.get_model_config_params()
-            model_config_params["model"] = model_path
-            vllm_model_config = VllmModelConfig(**model_config_params)
-        except Exception as e:
-            message = f"Failed to load model config: {e}."
-            return False, message
-
+        self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)

         # Only support vllm DefaultModelLoader for now
@@ -351,7 +290,7 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source(
-                    config.model,
+                    config.model_path,
                     revision=config.revision,
                     fall_back_to_pt=getattr(
                         self.model, "fall_back_to_pt_during_load", True
@@ -369,9 +308,9 @@ class ModelRunner:
                         quant_method.process_weights_after_loading(module)
             return model

-        with set_default_torch_dtype(vllm_model_config.dtype):
+        with set_default_torch_dtype(self.model_config.dtype):
             try:
-                iter = get_weight_iter(vllm_model_config)
+                iter = get_weight_iter(self.model_config)
             except Exception as e:
                 message = f"Failed to get weights iterator: {e}."
                 return False, message
@@ -383,20 +322,115 @@ class ModelRunner:
                )
                del iter
                gc.collect()
-                iter = get_weight_iter(self.vllm_model_config)
+                iter = get_weight_iter(self.model_config)
                self.model = model_load_weights(self.model, iter)
                return False, message

        self.model = model
        self.server_args.model_path = model_path
        self.server_args.load_format = load_format
-        self.vllm_model_config = vllm_model_config
        self.load_config = load_config
-        self.model_config.path = model_path

        logger.info("Update weights end.")
        return True, "Succeeded to update model weights."

+    def init_weights_update_group(
+        self,
+        master_address,
+        master_port,
+        rank_offset,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        """Initialize the Torch process group for model parameter updates.
+
+        `_model_update_group` is used in the RLHF workflow, where rank
+        0 is the actor model in the training engine, and the other ranks are
+        the inference engine, which is used for rollout.
+
+        In the RLHF workflow, the training engine updates the model
+        weights/parameters online, and broadcasts them to the inference
+        engine through the `_model_update_group` process group.
+        """
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        rank = rank_offset + self.tp_rank
+
+        logger.info(
+            f"init custom process group: master_address={master_address}, master_port={master_port}, "
+            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        try:
+            self._model_update_group = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{master_port}",
+                world_size=world_size,
+                rank=rank,
+                group_name=group_name,
+            )
+            dist.barrier(group=self._model_update_group, device_ids=[rank])
+            return True, "Succeeded to initialize custom process group."
+        except Exception as e:
+            message = f"Failed to initialize custom process group: {e}."
+            logger.error(message)
+            return False, message
+
+    def update_weights_from_distributed(self, name, dtype, shape):
+        """
+        Update specific parameter in the model weights online
+        through `_model_update_group` process group.
+
+        Args:
+            name: the name of the parameter to be updated.
+            dtype: the data type of the parameter to be updated.
+            shape: the shape of the parameter to be updated.
+        """
+        target_dtype = (
+            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+        )
+        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
+
+        assert (
+            self._model_update_group is not None
+        ), "model update group must be initialized"
+
+        try:
+            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
+            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
+            self.model.load_weights([(name, weights)])
+            return True, f"Succeeded to update parameter {name} online."
+
+        except Exception as e:
+            error_msg = (
+                f"Failed to update parameter online: {e}. "
+                f"The full weights of the ModelRunner are partially updated. "
+                f"Please discard the whole weights."
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        # TODO: (chenyang) Add support for Qwen models.
+        try:
+            return self.model.get_weights_by_name(
+                name, truncate_size, tp_size=self.tp_size
+            )
+        except Exception as e:
+            logger.error(f"Error when getting parameter {name}: {e}")
+            return None
+
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
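The three methods added above form the inference-side API for the RLHF weight-sync flow described in their docstrings: the engine first joins a dedicated process group, then receives individual parameters broadcast by the training engine. The following is a minimal, hypothetical driver sketch of that call order, not taken from the release; `model_runner` stands in for the ModelRunner instance owned by the TP worker, the rendezvous address/port, group name, parameter name, and shape are placeholders, tp_size=1 is assumed, and the training engine is expected to join the same rendezvous as rank 0 and broadcast the tensor with src=0.

import torch

# Hypothetical inference-side driver; the training engine must join the same
# tcp://host:port rendezvous as rank 0 and broadcast the tensor with src=0.
ok, msg = model_runner.init_weights_update_group(
    master_address="127.0.0.1",
    master_port=29500,            # placeholder rendezvous port
    rank_offset=1,                # inference TP ranks follow the training rank(s)
    world_size=2,                 # 1 training rank + tp_size (assumed 1) inference ranks
    group_name="weight_sync",     # any non-empty name
    backend="nccl",
)
assert ok, msg

# Receive one updated parameter broadcast by the training engine (rank 0).
ok, msg = model_runner.update_weights_from_distributed(
    name="model.layers.0.self_attn.q_proj.weight",  # placeholder parameter name
    dtype=torch.bfloat16,
    shape=(4096, 4096),                             # must match the broadcast tensor
)
assert ok, msg

# Optional sanity check, intended for tests only (see the get_weights_by_name docstring).
weight_sample = model_runner.get_weights_by_name(
    "model.layers.0.self_attn.q_proj.weight", truncate_size=5
)

In the release itself these methods are presumably reached through the new request plumbing in io_struct.py, tokenizer_manager.py, and scheduler.py listed above rather than called directly.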
@@ -547,6 +581,8 @@ class ModelRunner:
                self.attn_backend = DoubleSparseAttnBackend(self)
            else:
                self.attn_backend = TritonAttnBackend(self)
+        elif self.server_args.attention_backend == "torch_native":
+            self.attn_backend = TorchNativeAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -583,8 +619,10 @@ class ModelRunner:
        if self.server_args.disable_cuda_graph:
            return

+        tic = time.time()
        logger.info("Capture cuda graph begin. This can take up to several minutes.")
        self.cuda_graph_runner = CudaGraphRunner(self)
+        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -694,55 +732,3 @@ class ModelRunner:
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"
-
-
-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            try:
-                module = importlib.import_module(name)
-            except Exception as e:
-                logger.warning(f"Ignore import error when loading {name}. {e}")
-                if crash_on_warnings():
-                    raise ValueError(f"Ignore import error when loading {name}. {e}")
-                continue
-            if hasattr(module, "EntryClass"):
-                entry = module.EntryClass
-                if isinstance(
-                    entry, list
-                ):  # To support multiple model classes in one module
-                    for tmp in entry:
-                        assert (
-                            tmp.__name__ not in model_arch_name_to_cls
-                        ), f"Duplicated model implementation for {tmp.__name__}"
-                        model_arch_name_to_cls[tmp.__name__] = tmp
-                else:
-                    assert (
-                        entry.__name__ not in model_arch_name_to_cls
-                    ), f"Duplicated model implementation for {entry.__name__}"
-                    model_arch_name_to_cls[entry.__name__] = entry
-
-    return model_arch_name_to_cls
-
-
-def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-    model_arch_name_to_cls = import_model_classes()
-
-    if model_arch not in model_arch_name_to_cls:
-        raise ValueError(
-            f"Unsupported architectures: {model_arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_arch_name_to_cls[model_arch]
-
-
-# Monkey patch model loader
-setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
-setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
-setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
-setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
sglang/srt/model_loader/__init__.py (new file):
@@ -0,0 +1,34 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
+
+from torch import nn
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
+from sglang.srt.model_loader.utils import (
+    get_architecture_class_name,
+    get_model_architecture,
+)
+
+
+def get_model(
+    *,
+    model_config: ModelConfig,
+    load_config: LoadConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    loader = get_model_loader(load_config)
+    return loader.load_model(
+        model_config=model_config,
+        device_config=device_config,
+    )
+
+
+__all__ = [
+    "get_model",
+    "get_model_loader",
+    "BaseModelLoader",
+    "get_architecture_class_name",
+    "get_model_architecture",
+]
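This new package replaces the earlier dependency on vllm.model_executor.model_loader.get_model, which is why ModelRunner.load_model above now passes sglang's own ModelConfig, LoadConfig, and DeviceConfig. Below is a standalone sketch of the keyword-only entry point defined here; the positional ModelConfig argument, the default LoadConfig, and the model path are assumptions and placeholders rather than documented usage:

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model

# Hypothetical direct use of get_model() as defined in this new module.
model = get_model(
    model_config=ModelConfig("meta-llama/Llama-3.1-8B-Instruct"),  # placeholder path
    load_config=LoadConfig(),            # default load format (assumed)
    device_config=DeviceConfig("cuda"),
)
print(type(model).__name__)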