sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
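
A recurring change across the model files below is that `default_weight_loader` now comes from the new `sglang.srt.model_loader.weight_utils` package instead of `vllm.model_executor.model_loader.weight_utils`. The hunks only show the import move and the call sites (`weight_loader(param, loaded_weight)`), not the loader body; the sketch below is a minimal illustration of what such a default loader conventionally does, under a hypothetical name, not a copy of sglang's implementation.

```python
import torch


def default_weight_loader_sketch(
    param: torch.nn.Parameter, loaded_weight: torch.Tensor
) -> None:
    """Minimal sketch: copy a checkpoint tensor into an existing parameter.

    The real sglang.srt.model_loader.weight_utils.default_weight_loader may
    handle more cases; this only mirrors the calling convention used in the
    load_weights loops shown in the hunks below.
    """
    assert param.size() == loaded_weight.size(), (
        f"shape mismatch: {tuple(param.size())} vs {tuple(loaded_weight.size())}"
    )
    param.data.copy_(loaded_weight)


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(4, 4))
    default_weight_loader_sketch(p, torch.ones(4, 4))
    print(p.sum().item())  # 16.0
```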
sglang/srt/models/mixtral.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
  from transformers import MixtralConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class MixtralMoE(nn.Module):
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
  self,
  config: MixtralConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -310,7 +309,7 @@ class MixtralForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -340,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
  continue
  name = name.replace(weight_name, param_name)
  # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ if (
+ name.endswith(".bias") or name.endswith("_bias")
+ ) and name not in params_dict:
  continue

  param = params_dict[name]
@@ -354,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
  continue
  name = name.replace(weight_name, param_name)

+ if (
+ name.endswith(".bias") or name.endswith("_bias")
+ ) and name not in params_dict:
+ continue
  param = params_dict[name]
  weight_loader = param.weight_loader
  weight_loader(
@@ -366,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
  break
  else:
  # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ if (
+ name.endswith(".bias") or name.endswith("_bias")
+ ) and name not in params_dict:
  continue
  # Skip loading kv_scale from ckpts towards new design.
  if name.endswith(".kv_scale") and name not in params_dict:
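
The hunks above show two changes repeated in most model files in this release: the GPTQ bias skip now also matches `_bias` suffixes, and `logits_processor` is called with the `lm_head` module rather than its `.weight` tensor. The snippet below is only a sketch of that second calling convention using toy classes; it is not sglang's `LogitsProcessor`, whose real implementation lives in `sglang/srt/layers/logits_processor.py`.

```python
import torch
import torch.nn as nn


class ToyLogitsProcessor(nn.Module):
    """Toy illustration: receive the head *module*, not a raw weight tensor.

    Passing the module lets the processor defer to whatever projection logic
    the head implements (plain, tied, or quantized) instead of assuming a
    dense matrix is always available as .weight.
    """

    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        return lm_head(hidden_states)


if __name__ == "__main__":
    hidden = torch.randn(2, 8)            # [num_tokens, hidden_size]
    head = nn.Linear(8, 32, bias=False)   # stand-in for a real LM head
    logits = ToyLogitsProcessor()(hidden, head)
    print(logits.shape)                   # torch.Size([2, 32])
```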
sglang/srt/models/mixtral_quant.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class MixtralMLP(nn.Module):
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
  self,
  config: MixtralConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -343,7 +342,7 @@ class QuantMixtralForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mllama.py CHANGED
@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
  _prepare_aspect_ratio_attention_mask,
  )
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.layernorm import RMSNorm
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import ImageInputs
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP


@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
  self,
  config: config_mllama.MllamaTextConfig,
  quant_config: Optional[QuantizationConfig],
- cache_config=None,
  ):
  super().__init__()
  self.padding_id = config.pad_token_id
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
  self,
  config: config_mllama.MllamaTextConfig,
  quant_config: Optional[QuantizationConfig],
- cache_config=None,
  ):
  super().__init__()
  self.vocab_size = config.vocab_size
- self.model = MllamaTextModel(config, cache_config, quant_config)
+ self.model = MllamaTextModel(config, quant_config)
  self.lm_head = ParallelLMHead(
  config.vocab_size,
  config.hidden_size,
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
  self,
  config: config_mllama.MllamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ):
  super().__init__()
  self.vocab_size = config.text_config.vocab_size
@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
  self.vision_model = MllamaVisionModel(config.vision_config)
  self.language_model = MllamaForCausalLM(
  config.text_config,
- cache_config=cache_config,
  quant_config=quant_config,
  )
  self.multi_modal_projector = nn.Linear(
@@ -966,7 +962,7 @@ class MllamaForConditionalGeneration(nn.Module):
  skip_cross_attention=skip_cross_attention,
  )
  return self.logits_processor(
- input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.language_model.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmo.py CHANGED
@@ -22,7 +22,6 @@ from torch import nn
  from transformers import OlmoConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers


@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
  def __init__(
  self,
  config: OlmoConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -306,7 +305,7 @@ class OlmoForCausalLM(nn.Module):
  input_embeds=input_embeds,
  )
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +325,6 @@ class OlmoForCausalLM(nn.Module):
  # Models trained using ColossalAI may include these tensors in
  # the checkpoint. Skip them.
  continue
- # With tie_word_embeddings, we can skip lm_head.weight
- # The weight might appear unnecessarily in the files if the model is
- # processed with quantization, LoRA, fine-tuning, etc.
- if self.config.tie_word_embeddings and "lm_head.weight" in name:
- continue
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
sglang/srt/models/olmo2.py CHANGED
@@ -312,7 +312,6 @@ class Olmo2ForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
sglang/srt/models/olmoe.py CHANGED
@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
  RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.utils import print_warning_once

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import make_layers
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import make_layers, print_warning_once


  class OlmoeMoE(nn.Module):
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -321,7 +319,7 @@ class OlmoeForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/phi3_small.py CHANGED
@@ -7,8 +7,6 @@ from transformers import Phi3Config
  from transformers.configuration_utils import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.models.utils import make_layers

  from sglang.srt.layers.linear import (
  MergedColumnParallelLinear,
@@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import make_layers


  @torch.jit.script
@@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module):
  self,
  config: PretrainedConfig,
  layer_id: int,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module):
  super().__init__()

  self.config = config
- cache_config = None
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size, config.hidden_size
  )
@@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module):
  self.start_layer, self.end_layer, self.layers = make_layers(
  config.num_hidden_layers,
  lambda prefix: Phi3SmallDecoderLayer(
- config, int(prefix.split(".")[-1]), cache_config, quant_config
+ config, int(prefix.split(".")[-1]), quant_config
  ),
  prefix=f"{prefix}.layers",
  )
@@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module):
  self,
  config: Phi3Config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ):

  super().__init__()
@@ -397,10 +394,13 @@ class Phi3SmallForCausalLM(nn.Module):

  def compute_logits(
  self,
+ input_ids: torch.LongTensor,
  hidden_states: torch.Tensor,
  sampling_metadata,
  ) -> Optional[torch.Tensor]:
- logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+ logits = self.logits_processor(
+ input_ids, self.lm_head, hidden_states, sampling_metadata
+ )
  if self.dummy_token_indices is not None and logits is not None:
  logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
  return logits
@@ -422,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module):

  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  else:
sglang/srt/models/qwen.py CHANGED
@@ -22,7 +22,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class QWenMLP(nn.Module):
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ):
  super().__init__()
  self.config = config
@@ -260,7 +259,7 @@ class QWenLMHeadModel(nn.Module):
  ):
  hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen2.py CHANGED
@@ -22,7 +22,6 @@ import torch
  from torch import nn
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers

  Qwen2Config = None
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
+ quant_config=quant_config,
  )
  self.layers = make_layers(
  config.num_hidden_layers,
@@ -270,13 +271,17 @@ class Qwen2ForCausalLM(nn.Module):
  self,
  config: Qwen2Config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
  self.model = Qwen2Model(config, quant_config=quant_config)
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=quant_config
+ )
  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

@@ -292,7 +297,7 @@ class Qwen2ForCausalLM(nn.Module):
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )
  else:
  return self.pooler(hidden_states, forward_batch)
@@ -306,6 +311,7 @@ class Qwen2ForCausalLM(nn.Module):
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
+
  params_dict = dict(self.named_parameters())
  for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +341,6 @@ class Qwen2ForCausalLM(nn.Module):
  param = params_dict[name]
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)
- if (
- self.config.tie_word_embeddings
- and name == "model.embed_tokens.weight"
- ):
- weight_loader(params_dict["lm_head.weight"], loaded_weight)


  EntryClass = Qwen2ForCausalLM
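
The `Qwen2ForCausalLM` hunks above replace the old trick of copying `model.embed_tokens.weight` into `lm_head.weight` during `load_weights` with direct module sharing: when `tie_word_embeddings` is set, `self.lm_head` simply is `self.model.embed_tokens`. A self-contained sketch of that tying pattern, using toy classes rather than sglang's `ParallelLMHead`:

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    """Toy model showing head/embedding tying via module aliasing."""

    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # One module, one parameter: no separate lm_head.weight to load.
            self.lm_head = self.embed_tokens
        else:
            self.lm_head = nn.Embedding(vocab_size, hidden_size)

    def logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states onto the (possibly shared) embedding matrix.
        return hidden_states @ self.lm_head.weight.t()


if __name__ == "__main__":
    m = TinyTiedLM(vocab_size=100, hidden_size=16, tie_word_embeddings=True)
    assert m.lm_head.weight.data_ptr() == m.embed_tokens.weight.data_ptr()
    print(m.logits(torch.randn(4, 16)).shape)  # torch.Size([4, 100])
```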
sglang/srt/models/qwen2_moe.py CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class Qwen2MoeMLP(nn.Module):
@@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
  self,
  config: PretrainedConfig,
  layer_id: int,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
  rope_theta=rope_theta,
  rope_scaling=rope_scaling,
  max_position_embeddings=max_position_embeddings,
- cache_config=cache_config,
  quant_config=quant_config,
  )

@@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module):
  )
  self.layers = nn.ModuleList(
  [
- Qwen2MoeDecoderLayer(
- config, layer_id, cache_config, quant_config=quant_config
- )
+ Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
  for layer_id in range(config.num_hidden_layers)
  ]
  )
@@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
  self.torchao_config = global_server_args_dict["torchao_config"]
- self.model = Qwen2MoeModel(config, cache_config, quant_config)
+ self.model = Qwen2MoeModel(config, quant_config)
  self.lm_head = ParallelLMHead(
  config.vocab_size, config.hidden_size, quant_config=quant_config
  )
@@ -376,7 +369,7 @@ class Qwen2MoeForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen2_vl.py CHANGED
@@ -30,12 +30,10 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange, repeat
- from vllm.config import CacheConfig, MultiModalConfig
  from vllm.distributed import parallel_state
  from vllm.distributed import utils as dist_utils
  from vllm.logger import init_logger
  from vllm.model_executor.layers.activation import QuickGELU
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
  from sglang.srt.hf_transformers_utils import get_processor
@@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
  from sglang.srt.managers.schedule_batch import ImageInputs
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model

  logger = init_logger(__name__)
@@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
  def __init__(
  self,
  config: Qwen2VLConfig,
- cache_config: Optional[CacheConfig] = None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -668,7 +666,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):

  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )
  else:
  return self.pooler(hidden_states, forward_batch)
@@ -686,8 +684,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
  for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name:
  continue
- if self.config.tie_word_embeddings and "lm_head.weight" in name:
- continue
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
sglang/srt/models/registry.py ADDED
@@ -0,0 +1,99 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
+
+ import importlib
+ import logging
+ import pkgutil
+ from dataclasses import dataclass, field
+ from functools import lru_cache
+ from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union
+
+ import torch.nn as nn
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class _ModelRegistry:
+ # Keyed by model_arch
+ models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
+
+ def get_supported_archs(self) -> AbstractSet[str]:
+ return self.models.keys()
+
+ def _raise_for_unsupported(self, architectures: List[str]):
+ all_supported_archs = self.get_supported_archs()
+
+ if any(arch in all_supported_archs for arch in architectures):
+ raise ValueError(
+ f"Model architectures {architectures} failed "
+ "to be inspected. Please check the logs for more details."
+ )
+
+ raise ValueError(
+ f"Model architectures {architectures} are not supported for now. "
+ f"Supported architectures: {all_supported_archs}"
+ )
+
+ def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
+ if model_arch not in self.models:
+ return None
+
+ return self.models[model_arch]
+
+ def _normalize_archs(
+ self,
+ architectures: Union[str, List[str]],
+ ) -> List[str]:
+ if isinstance(architectures, str):
+ architectures = [architectures]
+ if not architectures:
+ logger.warning("No model architectures are specified")
+
+ return architectures
+
+ def resolve_model_cls(
+ self,
+ architectures: Union[str, List[str]],
+ ) -> Tuple[Type[nn.Module], str]:
+ architectures = self._normalize_archs(architectures)
+
+ for arch in architectures:
+ model_cls = self._try_load_model_cls(arch)
+ if model_cls is not None:
+ return (model_cls, arch)
+
+ return self._raise_for_unsupported(architectures)
+
+
+ @lru_cache()
+ def import_model_classes():
+ model_arch_name_to_cls = {}
+ package_name = "sglang.srt.models"
+ package = importlib.import_module(package_name)
+ for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+ if not ispkg:
+ try:
+ module = importlib.import_module(name)
+ except Exception as e:
+ logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+ continue
+ if hasattr(module, "EntryClass"):
+ entry = module.EntryClass
+ if isinstance(
+ entry, list
+ ): # To support multiple model classes in one module
+ for tmp in entry:
+ assert (
+ tmp.__name__ not in model_arch_name_to_cls
+ ), f"Duplicated model implementation for {tmp.__name__}"
+ model_arch_name_to_cls[tmp.__name__] = tmp
+ else:
+ assert (
+ entry.__name__ not in model_arch_name_to_cls
+ ), f"Duplicated model implementation for {entry.__name__}"
+ model_arch_name_to_cls[entry.__name__] = entry
+
+ return model_arch_name_to_cls
+
+
+ ModelRegistry = _ModelRegistry(import_model_classes())
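
The new `registry.py` builds an explicit model registry: `import_model_classes()` scans every module in `sglang.srt.models` for an `EntryClass` attribute (a class or a list of classes) and keys them by class name. A usage sketch based only on the code above, assuming sglang 0.4.0 is installed; the architecture string would normally come from a model's HuggingFace `config.json`:

```python
from sglang.srt.models.registry import ModelRegistry

# Architectures as they appear in a HuggingFace config.json,
# e.g. {"architectures": ["Qwen2ForCausalLM"]}.
model_cls, arch = ModelRegistry.resolve_model_cls(["Qwen2ForCausalLM"])
print(arch, model_cls)

# Everything discovered through the EntryClass convention; resolve_model_cls
# raises ValueError listing these names when nothing matches.
print(sorted(ModelRegistry.get_supported_archs()))
```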
sglang/srt/models/stablelm.py CHANGED
@@ -26,7 +26,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class StablelmMLP(nn.Module):
@@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -261,7 +260,7 @@ class StableLmForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):