sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0

sglang/srt/model_executor/model_runner.py
@@ -14,35 +14,30 @@
  """ModelRunner runs the forward passes of the models."""

  import gc
- import importlib
- import importlib.resources
- import inspect
  import json
  import logging
- import pkgutil
- from functools import lru_cache
- from typing import Optional, Type
+ import time
+ from typing import Optional

  import torch
- import torch.nn as nn
- from vllm.config import DeviceConfig, LoadConfig
- from vllm.config import ModelConfig as VllmModelConfig
+ import torch.distributed as dist
  from vllm.distributed import (
      get_tp_group,
      init_distributed_environment,
      initialize_model_parallel,
      set_custom_all_reduce,
  )
- from vllm.distributed.parallel_state import in_the_same_node_as
- from vllm.model_executor.model_loader import get_model
- from vllm.model_executor.models import ModelRegistry

+ from sglang.srt.configs.device_config import DeviceConfig
+ from sglang.srt.configs.load_config import LoadConfig
  from sglang.srt.configs.model_config import AttentionArch, ModelConfig
  from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
  from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
  from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.sampler import Sampler
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
  from sglang.srt.lora.lora_manager import LoRAManager
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.mem_cache.memory_pool import (
@@ -52,14 +47,15 @@ from sglang.srt.mem_cache.memory_pool import (
      ReqToTokenPool,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader import get_model
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
-     crash_on_warnings,
      enable_show_time_cost,
      get_available_gpu_memory,
+     init_custom_process_group,
      is_hip,
-     monkey_patch_vllm_model_config,
+     monkey_patch_vllm_gguf_config,
      monkey_patch_vllm_p2p_access_check,
      set_cpu_offload_max_bytes,
  )
@@ -115,11 +111,13 @@ class ModelRunner:
          )

          if self.is_multimodal:
+             server_args.chunked_prefill_size = -1
+             self.mem_fraction_static *= 0.95
              logger.info(
-                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
+                 f"and turn off chunked prefill "
+                 f"because this is a multimodal model."
              )
-             server_args.chunked_prefill_size = None
-             self.mem_fraction_static *= 0.95
          # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
          if self.model_config.hf_config.architectures == [
              "Qwen2VLForConditionalGeneration"
@@ -129,7 +127,7 @@
          # Global vars
          if server_args.show_time_cost:
              enable_show_time_cost()
-         if server_args.disable_disk_cache:
+         if server_args.disable_outlines_disk_cache:
              from outlines.caching import disable_cache

              disable_cache()
@@ -143,17 +141,20 @@
                  "torchao_config": server_args.torchao_config,
                  "enable_nan_detection": server_args.enable_nan_detection,
                  "enable_dp_attention": server_args.enable_dp_attention,
+                 "enable_ep_moe": server_args.enable_ep_moe,
              }
          )

          set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

-         # Init components
+         # Get memory before model loading
          min_per_gpu_memory = self.init_torch_distributed()
+
+         # Load the model
          self.sampler = Sampler()
          self.load_model()

-         # Apply torch TP if model supports it
+         # Apply torch TP if the model supports it
          supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
          if self.tp_size > 1 and supports_torch_tp:
              self.apply_torch_tp()
@@ -161,6 +162,11 @@
          else:
              self.torch_tp_applied = False

+         apply_torchao_config_to_model(
+             self.model, global_server_args_dict["torchao_config"]
+         )
+
+         # Init memory pool and attention backends
          if server_args.lora_paths is not None:
              self.init_lora_manager()
          self.init_memory_pool(
@@ -209,16 +215,6 @@
          )
          self.tp_group = get_tp_group()

-         # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
-         # so we disable padding in cuda graph.
-         if self.device == "cuda" and not all(
-             in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
-         ):
-             self.server_args.disable_cuda_graph_padding = True
-             logger.info(
-                 "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
-             )
-
          # Check memory for tensor parallelism
          if self.tp_size > 1:
              local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
@@ -229,49 +225,6 @@

          return min_per_gpu_memory

-     def setup_model(self):
-         try:
-             from vllm.config import VllmConfig
-
-             vllm_config = VllmConfig()
-             vllm_config.model_config = self.vllm_model_config
-             vllm_config.load_config = self.load_config
-             vllm_config.device_config = DeviceConfig(self.device)
-             vllm_config.quant_config = VllmConfig._get_quantization_config(
-                 vllm_config.model_config, vllm_config.load_config
-             )
-             return get_model(vllm_config=vllm_config)
-         except ImportError:
-             pass
-
-         return get_model(
-             model_config=self.vllm_model_config,
-             load_config=self.load_config,
-             device_config=DeviceConfig(self.device),
-             parallel_config=None,
-             scheduler_config=None,
-             lora_config=None,
-             cache_config=None,
-         )
-
-     def get_model_config_params(self):
-         sig = inspect.signature(VllmModelConfig.__init__)
-         params = {
-             "model": self.server_args.model_path,
-             "quantization": self.server_args.quantization,
-             "tokenizer": None,
-             "tokenizer_mode": None,
-             "trust_remote_code": self.server_args.trust_remote_code,
-             "dtype": self.server_args.dtype,
-             "seed": self.server_args.random_seed,
-             "skip_tokenizer_init": True,
-         }
-
-         if "task" in sig.parameters:
-             params["task"] = ""
-
-         return params
-
      def load_model(self):
          logger.info(
              f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -285,6 +238,7 @@
                  "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
              )
              self.server_args.dtype = "float16"
+             self.model_config.dtype = torch.float16
              if torch.cuda.get_device_capability()[1] < 5:
                  raise RuntimeError("SGLang only supports sm75 and above.")

@@ -293,21 +247,21 @@
              load_format=self.server_args.load_format,
              download_dir=self.server_args.download_dir,
          )
-         monkey_patch_vllm_model_config()
-         self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
-         if self.model_config.model_override_args is not None:
-             self.vllm_model_config.hf_config.update(
-                 self.model_config.model_override_args
-             )

-         self.model = self.setup_model()
+         if self.server_args.load_format == "gguf":
+             monkey_patch_vllm_gguf_config()
+         self.model = get_model(
+             model_config=self.model_config,
+             load_config=self.load_config,
+             device_config=DeviceConfig(self.device),
+         )

          self.sliding_window_size = (
              self.model.get_attention_sliding_window_size()
              if hasattr(self.model, "get_attention_sliding_window_size")
              else None
          )
-         self.dtype = self.vllm_model_config.dtype
+         self.dtype = self.model_config.dtype

          logger.info(
              f"Load weight end. "
@@ -316,30 +270,22 @@
              f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
          )

-     def update_weights(self, model_path: str, load_format: str):
-         """Update weights in-place."""
-         from vllm.model_executor.model_loader.loader import (
+     def update_weights_from_disk(self, model_path: str, load_format: str):
+         """Update engine weights online from disk."""
+         from sglang.srt.model_loader.loader import (
              DefaultModelLoader,
              device_loading_context,
              get_model_loader,
          )
-         from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+         from sglang.srt.model_loader.utils import set_default_torch_dtype

          logger.info(
-             f"Update weights begin. "
+             f"Update engine weights online from disk begin. "
              f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
          )

          target_device = torch.device(self.device)
-
-         try:
-             model_config_params = self.get_model_config_params()
-             model_config_params["model"] = model_path
-             vllm_model_config = VllmModelConfig(**model_config_params)
-         except Exception as e:
-             message = f"Failed to load model config: {e}."
-             return False, message
-
+         self.model_config.model_path = model_path
          load_config = LoadConfig(load_format=load_format)

          # Only support vllm DefaultModelLoader for now
@@ -351,7 +297,7 @@
          def get_weight_iter(config):
              iter = loader._get_weights_iterator(
                  DefaultModelLoader.Source(
-                     config.model,
+                     config.model_path,
                      revision=config.revision,
                      fall_back_to_pt=getattr(
                          self.model, "fall_back_to_pt_during_load", True
@@ -369,9 +315,9 @@
                          quant_method.process_weights_after_loading(module)
              return model

-         with set_default_torch_dtype(vllm_model_config.dtype):
+         with set_default_torch_dtype(self.model_config.dtype):
              try:
-                 iter = get_weight_iter(vllm_model_config)
+                 iter = get_weight_iter(self.model_config)
              except Exception as e:
                  message = f"Failed to get weights iterator: {e}."
                  return False, message
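
Together, the hunks above rename update_weights to update_weights_from_disk and move it onto sglang.srt.model_loader. A minimal usage sketch (the runner handle and checkpoint path are hypothetical; the method returns a (success, message) pair, and the following hunk shows it restoring the original weights when loading fails):

# `runner` stands for an already-initialized ModelRunner; path and format are illustrative.
success, message = runner.update_weights_from_disk(
    model_path="/checkpoints/my-model-step-1200",
    load_format="auto",  # assumed default format name; match your checkpoint layout
)
if not success:
    raise RuntimeError(message)
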
@@ -383,20 +329,115 @@
                  )
                  del iter
                  gc.collect()
-                 iter = get_weight_iter(self.vllm_model_config)
+                 iter = get_weight_iter(self.model_config)
                  self.model = model_load_weights(self.model, iter)
                  return False, message

          self.model = model
          self.server_args.model_path = model_path
          self.server_args.load_format = load_format
-         self.vllm_model_config = vllm_model_config
          self.load_config = load_config
-         self.model_config.path = model_path

          logger.info("Update weights end.")
          return True, "Succeeded to update model weights."

+     def init_weights_update_group(
+         self,
+         master_address,
+         master_port,
+         rank_offset,
+         world_size,
+         group_name,
+         backend="nccl",
+     ):
+         """Initialize the Torch process group for model parameter updates.
+
+         `_model_update_group` is used in the RLHF workflow, where rank
+         0 is the actor model in the training engine, and the other ranks are
+         the inference engine, which is used for rollout.
+
+         In the RLHF workflow, the training engine updates the model
+         weights/parameters online, and broadcasts them to the inference
+         engine through the `_model_update_group` process group.
+         """
+         assert (
+             torch.distributed.is_initialized()
+         ), "Default torch process group must be initialized"
+         assert group_name != "", "Group name cannot be empty"
+
+         rank = rank_offset + self.tp_rank
+
+         logger.info(
+             f"init custom process group: master_address={master_address}, master_port={master_port}, "
+             f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+         )
+
+         try:
+             self._model_update_group = init_custom_process_group(
+                 backend=backend,
+                 init_method=f"tcp://{master_address}:{master_port}",
+                 world_size=world_size,
+                 rank=rank,
+                 group_name=group_name,
+             )
+             dist.barrier(group=self._model_update_group, device_ids=[rank])
+             return True, "Succeeded to initialize custom process group."
+         except Exception as e:
+             message = f"Failed to initialize custom process group: {e}."
+             logger.error(message)
+             return False, message
+
+     def update_weights_from_distributed(self, name, dtype, shape):
+         """
+         Update specific parameter in the model weights online
+         through `_model_update_group` process group.
+
+         Args:
+             name: the name of the parameter to be updated.
+             dtype: the data type of the parameter to be updated.
+             shape: the shape of the parameter to be updated.
+         """
+         target_dtype = (
+             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+         )
+         current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
+
+         assert (
+             self._model_update_group is not None
+         ), "model update group must be initialized"
+
+         try:
+             weights = torch.empty(shape, dtype=target_dtype, device=self.device)
+             torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
+             self.model.load_weights([(name, weights)])
+             return True, f"Succeeded to update parameter {name} online."
+
+         except Exception as e:
+             error_msg = (
+                 f"Failed to update parameter online: {e}. "
+                 f"The full weights of the ModelRunner are partially updated. "
+                 f"Please discard the whole weights."
+             )
+             logger.error(error_msg)
+             return False, error_msg
+
+     def get_weights_by_name(
+         self, name: str, truncate_size: int = 100
+     ) -> Optional[torch.Tensor]:
+         """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+         Only used for unit test with an unoptimized performance.
+         For optimized performance, please use torch.save and torch.load.
+         """
+         # TODO: (chenyang) Add support for Qwen models.
+         try:
+             return self.model.get_weights_by_name(
+                 name, truncate_size, tp_size=self.tp_size
+             )
+         except Exception as e:
+             logger.error(f"Error when getting parameter {name}: {e}")
+             return None
+
      def init_lora_manager(self):
          self.lora_manager = LoRAManager(
              base_model=self.model,
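
The new init_weights_update_group, update_weights_from_distributed, and get_weights_by_name methods form the inference-engine side of an online (RLHF-style) weight-sync path. A hedged sketch of driving them (the address, port, group name, parameter name, and shape are illustrative only; a training process must hold rank 0 of the same TCP-initialized group and issue one matching broadcast per update call):

import torch

# `runner` stands for an initialized ModelRunner inside the rollout engine.
ok, msg = runner.init_weights_update_group(
    master_address="127.0.0.1",        # illustrative rendezvous address/port
    master_port=29600,
    rank_offset=1,                     # trainer owns rank 0; engine TP ranks follow
    world_size=1 + runner.tp_size,     # one trainer rank plus the engine's TP ranks
    group_name="weight_sync",          # any non-empty name
    backend="nccl",
)
assert ok, msg

# For every parameter, the trainer broadcasts the updated tensor from rank 0;
# the engine allocates a buffer of the announced dtype/shape, receives it,
# and loads it in place via model.load_weights.
ok, msg = runner.update_weights_from_distributed(
    name="model.embed_tokens.weight",  # illustrative parameter name
    dtype=torch.bfloat16,
    shape=(128256, 4096),              # must match the trainer-side tensor
)
assert ok, msg

# Spot-check the result (unit-test helper, unoptimized by design).
sample = runner.get_weights_by_name("model.embed_tokens.weight", truncate_size=5)
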
@@ -547,6 +588,8 @@
                  self.attn_backend = DoubleSparseAttnBackend(self)
              else:
                  self.attn_backend = TritonAttnBackend(self)
+         elif self.server_args.attention_backend == "torch_native":
+             self.attn_backend = TorchNativeAttnBackend(self)
          else:
              raise ValueError(
                  f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -583,8 +626,10 @@
          if self.server_args.disable_cuda_graph:
              return

+         tic = time.time()
          logger.info("Capture cuda graph begin. This can take up to several minutes.")
          self.cuda_graph_runner = CudaGraphRunner(self)
+         logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")

      def apply_torch_tp(self):
          logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -694,55 +739,3 @@
          if rope_scaling is None:
              return False
          return rope_scaling.get("type", None) == "mrope"
-
-
- @lru_cache()
- def import_model_classes():
-     model_arch_name_to_cls = {}
-     package_name = "sglang.srt.models"
-     package = importlib.import_module(package_name)
-     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-         if not ispkg:
-             try:
-                 module = importlib.import_module(name)
-             except Exception as e:
-                 logger.warning(f"Ignore import error when loading {name}. {e}")
-                 if crash_on_warnings():
-                     raise ValueError(f"Ignore import error when loading {name}. {e}")
-                 continue
-             if hasattr(module, "EntryClass"):
-                 entry = module.EntryClass
-                 if isinstance(
-                     entry, list
-                 ):  # To support multiple model classes in one module
-                     for tmp in entry:
-                         assert (
-                             tmp.__name__ not in model_arch_name_to_cls
-                         ), f"Duplicated model implementation for {tmp.__name__}"
-                         model_arch_name_to_cls[tmp.__name__] = tmp
-                 else:
-                     assert (
-                         entry.__name__ not in model_arch_name_to_cls
-                     ), f"Duplicated model implementation for {entry.__name__}"
-                     model_arch_name_to_cls[entry.__name__] = entry
-
-     return model_arch_name_to_cls
-
-
- def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
-     model_arch_name_to_cls = import_model_classes()
-
-     if model_arch not in model_arch_name_to_cls:
-         raise ValueError(
-             f"Unsupported architectures: {model_arch}. "
-             f"Supported list: {list(model_arch_name_to_cls.keys())}"
-         )
-     return model_arch_name_to_cls[model_arch]
-
-
- # Monkey patch model loader
- setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
- setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
- setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
- setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
- setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)

sglang/srt/model_loader/__init__.py (new file)
@@ -0,0 +1,34 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/__init__.py
+
+ from torch import nn
+
+ from sglang.srt.configs.device_config import DeviceConfig
+ from sglang.srt.configs.load_config import LoadConfig
+ from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.model_loader.loader import BaseModelLoader, get_model_loader
+ from sglang.srt.model_loader.utils import (
+     get_architecture_class_name,
+     get_model_architecture,
+ )
+
+
+ def get_model(
+     *,
+     model_config: ModelConfig,
+     load_config: LoadConfig,
+     device_config: DeviceConfig,
+ ) -> nn.Module:
+     loader = get_model_loader(load_config)
+     return loader.load_model(
+         model_config=model_config,
+         device_config=device_config,
+     )
+
+
+ __all__ = [
+     "get_model",
+     "get_model_loader",
+     "BaseModelLoader",
+     "get_architecture_class_name",
+     "get_model_architecture",
+ ]
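
This new module replaces vllm.model_executor.model_loader.get_model as the entry point for building a model from the three SGLang config objects. A standalone sketch under stated assumptions (the model id is illustrative; the ModelConfig constructor is assumed to take the checkpoint path as its first argument, which may differ in detail):

from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.model_loader import get_model

# Hypothetical configs: ModelConfig is assumed to accept the checkpoint path
# directly; LoadConfig/DeviceConfig are constructed the same way load_model() does.
model_config = ModelConfig("meta-llama/Llama-3.1-8B-Instruct")
load_config = LoadConfig(load_format="auto")  # assumed default format name
device_config = DeviceConfig("cuda")

model = get_model(
    model_config=model_config,
    load_config=load_config,
    device_config=device_config,
)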