sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0

sglang/srt/models/llava.py CHANGED
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
 
@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
 
@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
 
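A note on the recurring import change above: every model file in this release now takes default_weight_loader from the in-tree sglang.srt.model_loader.weight_utils module instead of vLLM's model loader. The swap is drop-in because the helper's contract is tiny; a minimal sketch of what such a default loader does (the shape-check message wording is an assumption, not the vendored implementation):

import torch


def default_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # Fallback used when a parameter defines no custom `weight_loader`:
    # check that the checkpoint tensor matches, then copy it in place.
    assert param.size() == loaded_weight.size(), (
        f"shape mismatch: {tuple(param.size())} vs {tuple(loaded_weight.size())}"
    )
    param.data.copy_(loaded_weight)
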
sglang/srt/models/llavavid.py CHANGED
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 
 
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/minicpm.py CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -308,12 +307,10 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
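The MiniCPM hunk above also shows the new logits-processor calling convention that repeats through this diff: forward() now hands LogitsProcessor the ParallelLMHead (or the tied embed_tokens module) rather than its raw .weight tensor. A rough sketch of why passing the module is useful, assuming the processor can defer to an attached quantization method (illustrative only, not the actual sglang LogitsProcessor):

import torch
from torch import nn


def project_to_vocab(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # With the module in hand, a quantized lm_head can be applied through its
    # own quant method; a plain head falls back to a dense matmul.
    quant_method = getattr(lm_head, "quant_method", None)
    if quant_method is not None:
        return quant_method.apply(lm_head, hidden_states)
    return torch.matmul(hidden_states, lm_head.weight.t())
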
sglang/srt/models/minicpm3.py CHANGED
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
             layer_id=layer_id,
         )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
             layer_id=layer_id,
         )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
 
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
@@ -585,12 +574,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/mixtral.py CHANGED
@@ -21,10 +21,13 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import MixtralConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,13 +38,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MixtralMoE(nn.Module):
@@ -65,6 +68,7 @@ class MixtralMoE(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.hidden_size = hidden_size
 
         # Gate always runs at half / full precision for now.
@@ -76,14 +80,13 @@ class MixtralMoE(nn.Module):
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=True,
             renormalize=True,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -97,6 +100,8 @@ class MixtralMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(orig_shape)
 
 
@@ -291,12 +296,10 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -310,7 +313,7 @@ class MixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -323,7 +326,8 @@ class MixtralForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
@@ -340,7 +344,9 @@ class MixtralForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                     continue
 
                 param = params_dict[name]
@@ -354,6 +360,10 @@ class MixtralForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -366,7 +376,9 @@ class MixtralForCausalLM(nn.Module):
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
@@ -380,7 +392,5 @@ class MixtralForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = MixtralForCausalLM
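The Mixtral changes capture the new expert-parallel MoE path: the implementation class is picked from global_server_args_dict["enable_ep_moe"] at construction time, and the tensor-parallel reduction that FusedMoE(reduce_results=True) previously performed internally is now issued explicitly by the caller. A condensed sketch of that control flow (gate and expert internals elided; the wrapper class below is hypothetical):

import torch
from torch import nn
from vllm.distributed import (
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)


class SparseMoEBlock(nn.Module):
    def __init__(self, gate: nn.Module, experts: nn.Module):
        super().__init__()
        # `experts` is either EPMoE or FusedMoE, chosen from the server args.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.gate = gate
        self.experts = experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        router_logits, _ = self.gate(hidden_states)
        out = self.experts(hidden_states, router_logits)
        # The MoE layer no longer reduces internally, so partial results from
        # each tensor-parallel rank are combined here.
        if self.tp_size > 1:
            out = tensor_model_parallel_all_reduce(out)
        return out.view(orig_shape)
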
sglang/srt/models/mixtral_quant.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MixtralMLP(nn.Module):
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -343,7 +342,7 @@ class QuantMixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mllama.py CHANGED
@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
     _prepare_aspect_ratio_attention_mask,
 )
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.layernorm import RMSNorm
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
 
 
@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.model = MllamaTextModel(config, cache_config, quant_config)
+        self.model = MllamaTextModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.text_config.vocab_size
@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self.vision_model = MllamaVisionModel(config.vision_config)
         self.language_model = MllamaForCausalLM(
             config.text_config,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.multi_modal_projector = nn.Linear(
@@ -966,7 +962,7 @@ class MllamaForConditionalGeneration(nn.Module):
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.language_model.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmo.py CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import OlmoConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
 
 
@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
     def __init__(
         self,
         config: OlmoConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -306,7 +305,7 @@ class OlmoForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +325,6 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
sglang/srt/models/olmo2.py CHANGED
@@ -312,7 +312,6 @@ class Olmo2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
34
34
  RowParallelLinear,
35
35
  )
36
36
  from vllm.model_executor.layers.rotary_embedding import get_rope
37
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
38
- from vllm.utils import print_warning_once
39
37
 
40
38
  from sglang.srt.layers.activation import SiluAndMul
41
39
  from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
48
46
  VocabParallelEmbedding,
49
47
  )
50
48
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
51
- from sglang.srt.utils import make_layers
49
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
50
+ from sglang.srt.utils import make_layers, print_warning_once
52
51
 
53
52
 
54
53
  class OlmoeMoE(nn.Module):
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
300
299
  def __init__(
301
300
  self,
302
301
  config: PretrainedConfig,
303
- cache_config=None,
304
302
  quant_config: Optional[QuantizationConfig] = None,
305
303
  ) -> None:
306
304
  super().__init__()
@@ -321,7 +319,7 @@ class OlmoeForCausalLM(nn.Module):
321
319
  ) -> torch.Tensor:
322
320
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
323
321
  return self.logits_processor(
324
- input_ids, hidden_states, self.lm_head.weight, forward_batch
322
+ input_ids, hidden_states, self.lm_head, forward_batch
325
323
  )
326
324
 
327
325
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
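olmoe.py also now takes print_warning_once from sglang.srt.utils rather than vllm.utils. The expected behavior is a warning that fires once per distinct message; a minimal stand-in under that assumption (not the vendored helper):

import functools
import logging

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def print_warning_once(msg: str) -> None:
    # lru_cache keys on the message string, so repeated calls with the same
    # warning log it only once.
    logger.warning(msg)
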
sglang/srt/models/phi3_small.py CHANGED
@@ -7,8 +7,6 @@ from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.utils import make_layers
 
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -19,14 +17,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import make_layers
 
 
 @torch.jit.script
@@ -235,7 +233,6 @@ class Phi3SmallDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -286,7 +283,6 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
-        cache_config = None
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size, config.hidden_size
         )
@@ -294,7 +290,7 @@ class Phi3SmallModel(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Phi3SmallDecoderLayer(
-                config, int(prefix.split(".")[-1]), cache_config, quant_config
+                config, int(prefix.split(".")[-1]), quant_config
             ),
             prefix=f"{prefix}.layers",
         )
@@ -339,7 +335,6 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
 
         super().__init__()
@@ -351,7 +346,6 @@ class Phi3SmallForCausalLM(nn.Module):
             quant_config=quant_config,
             prefix="model",
         )
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
         self.lm_head = ParallelLMHead(
@@ -397,10 +391,13 @@ class Phi3SmallForCausalLM(nn.Module):
 
     def compute_logits(
         self,
+        input_ids: torch.LongTensor,
         hidden_states: torch.Tensor,
         sampling_metadata,
     ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+        logits = self.logits_processor(
+            input_ids, self.lm_head, hidden_states, sampling_metadata
+        )
         if self.dummy_token_indices is not None and logits is not None:
             logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
         return logits
@@ -422,7 +419,7 @@ class Phi3SmallForCausalLM(nn.Module):
 
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
 
         else:
@@ -441,7 +438,5 @@ class Phi3SmallForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Phi3SmallForCausalLM
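phi3_small.py now builds its decoder stack with make_layers from sglang.srt.utils. Judging from the call site above, the helper returns (start_layer, end_layer, nn.ModuleList) and passes a dotted prefix to each layer factory; a minimal sketch with no pipeline-parallel partitioning (an assumption, not the actual sglang helper):

from typing import Callable, Tuple

from torch import nn


def make_layers(
    num_hidden_layers: int,
    layer_fn: Callable[..., nn.Module],
    prefix: str = "",
) -> Tuple[int, int, nn.ModuleList]:
    # Instantiate every layer locally; a pipeline-parallel variant would only
    # build the slice [start_layer, end_layer) owned by the current rank.
    layers = nn.ModuleList(
        layer_fn(prefix=f"{prefix}.{idx}") for idx in range(num_hidden_layers)
    )
    return 0, num_hidden_layers, layers
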
sglang/srt/models/qwen.py CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class QWenMLP(nn.Module):
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
@@ -260,7 +259,7 @@ class QWenLMHeadModel(nn.Module):
     ):
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):