sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
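Note: most of the per-model diffs below apply the same refactor: default_weight_loader now comes from sglang.srt.model_loader.weight_utils instead of vllm.model_executor.model_loader.weight_utils, the unused cache_config / lora_config constructor arguments are removed, and LogitsProcessor is called with the lm_head module itself rather than its .weight tensor. The sketch below only mirrors the call shapes visible in these diffs; ToyForCausalLM and ToyBackbone are made-up names, and the snippet assumes an initialized sglang 0.4.0 runtime rather than running standalone:

# Sketch of the 0.4.0-style model signatures seen in the diffs below.
# ToyForCausalLM / ToyBackbone are hypothetical; only the call shapes are taken from the diffs.
from typing import Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead, VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader  # new import path


class ToyBackbone(nn.Module):
    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)

    def forward(self, input_ids, positions, forward_batch: ForwardBatch) -> torch.Tensor:
        # Trivial stand-in for a real transformer backbone.
        return self.embed_tokens(input_ids)


class ToyForCausalLM(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        # Note: no cache_config / lora_config parameters in 0.4.0-style constructors.
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = ToyBackbone(config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(self, input_ids, positions, forward_batch: ForwardBatch) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        # 0.4.0 passes the lm_head module (not lm_head.weight) to the logits processor.
        return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)

    def load_weights(self, weights):
        # Same per-parameter loading pattern as the models in this diff.
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)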
sglang/srt/models/gemma2.py CHANGED
@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size

- # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
  from sglang.srt.layers.activation import GeluAndMul
  from sglang.srt.layers.layernorm import GemmaRMSNorm
  from sglang.srt.layers.linear import (
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers


@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
  head_dim: int,
  max_position_embeddings: int,
  rope_theta: float,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
  self,
  layer_id: int,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -205,7 +200,6 @@
  head_dim=config.head_dim,
  max_position_embeddings=config.max_position_embeddings,
  rope_theta=config.rope_theta,
- cache_config=cache_config,
  quant_config=quant_config,
  )
  self.hidden_size = config.hidden_size
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -273,7 +266,6 @@
  lambda idx, prefix: Gemma2DecoderLayer(
  layer_id=idx,
  config=config,
- cache_config=cache_config,
  quant_config=quant_config,
  ),
  prefix="",
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
- lora_config: Optional[LoRAConfig] = None,
  ) -> None:
- del lora_config # Unused.
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = Gemma2Model(config, cache_config, quant_config)
+ self.model = Gemma2Model(config, quant_config)
  self.logits_processor = LogitsProcessor(config)

  @torch.no_grad()
@@ -363,7 +352,7 @@ class Gemma2ForCausalLM(nn.Module):
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+ input_ids, hidden_states, self.model.embed_tokens, forward_batch
  )

  def get_attention_sliding_window_size(self):
sglang/srt/models/gemma2_reward.py CHANGED
@@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module):
  self,
  config: Gemma2Config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/gpt2.py CHANGED
@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
  import torch
  from torch import nn
  from transformers import GPT2Config
- from vllm.config import CacheConfig
  from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.activation import get_act_fn
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  # from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.linear import (
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class GPT2Attention(nn.Module):
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
  self,
  layer_id: int,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ):
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
  self,
  layer_id: int,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ):
@@ -150,7 +147,7 @@

  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  self.attn = GPT2Attention(
- layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+ layer_id, config, quant_config, prefix=f"{prefix}.attn"
  )
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
  def __init__(
  self,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ):
@@ -196,7 +192,7 @@
  self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  self.h = nn.ModuleList(
  [
- GPT2Block(i, config, cache_config, quant_config)
+ GPT2Block(i, config, quant_config)
  for i in range(config.num_hidden_layers)
  ]
  )
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
  def __init__(
  self,
  config: GPT2Config,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.transformer = GPT2Model(
- config, cache_config, quant_config, prefix="transformer"
- )
+ self.transformer = GPT2Model(config, quant_config, prefix="transformer")
  self.lm_head = self.transformer.wte

  self.logits_processor = LogitsProcessor(config)
@@ -247,7 +240,7 @@
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import GPTBigCodeConfig
- from vllm.config import LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
  self,
  layer_id: int,
  config: GPTBigCodeConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
  self,
  layer_id: int,
  config: GPTBigCodeConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ):
  super().__init__()
@@ -153,7 +150,7 @@
  inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

  self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
- self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+ self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
  self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
  self.mlp = GPTBigMLP(inner_dim, config, quant_config)

@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
  def __init__(
  self,
  config: GPTBigCodeConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
- lora_config: Optional[LoRAConfig] = None,
  ):
  super().__init__()
  self.config = config
  assert not config.add_cross_attention

  self.embed_dim = config.hidden_size
- lora_vocab = (
- (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
- if lora_config
- else 0
- )
+ lora_vocab = 0
  self.vocab_size = config.vocab_size + lora_vocab
  self.wte = VocabParallelEmbedding(
  self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@
  self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
  self.h = nn.ModuleList(
  [
- GPTBigCodeBlock(i, config, cache_config, quant_config)
+ GPTBigCodeBlock(i, config, quant_config)
  for i in range(config.num_hidden_layers)
  ]
  )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
  def __init__(
  self,
  config: GPTBigCodeConfig,
- cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
- lora_config: Optional[LoRAConfig] = None,
  ):
  super().__init__()

  self.config = config
- self.lora_config = lora_config

  self.quant_config = quant_config
- self.transformer = GPTBigCodeModel(
- config, cache_config, quant_config, lora_config
- )
+ self.transformer = GPTBigCodeModel(config, quant_config)
  self.lm_head = self.transformer.wte
  self.unpadded_vocab_size = config.vocab_size
- if lora_config:
- self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  self.logits_processor = LogitsProcessor(config)

  @torch.no_grad()
@@ -271,7 +255,7 @@
  ) -> torch.Tensor:
  hidden_states = self.transformer(input_ids, positions, forward_batch)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/grok.py CHANGED
@@ -24,7 +24,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
@@ -36,13 +35,13 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.loader import DefaultModelLoader
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class Grok1MoE(nn.Module):
@@ -285,12 +284,10 @@ class Grok1ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.torchao_config = global_server_args_dict["torchao_config"]
  self.model = Grok1Model(config, quant_config=quant_config)
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)
@@ -304,7 +301,7 @@
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -374,8 +371,6 @@ class Grok1ForCausalLM(nn.Module):
  )
  weight_loader(param, loaded_weight)

- apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-

  class Grok1ModelForCausalLM(Grok1ForCausalLM):
  """An alias for backward-compatbility."""
sglang/srt/models/internlm2.py CHANGED
@@ -21,7 +21,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -270,7 +269,7 @@
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  return self.logits_processor(
- input_ids, hidden_states, self.output.weight, forward_batch
+ input_ids, hidden_states, self.output, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2_reward.py CHANGED
@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/llama.py CHANGED
@@ -16,6 +16,7 @@
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
  """Inference-only LLaMA model compatible with HuggingFace weights."""

+ import logging
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
@@ -23,7 +24,6 @@ from torch import nn
  from transformers import LlamaConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -36,14 +36,16 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.layers.torchao_utils import apply_torchao_config_
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
  )
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)


  class LlamaMLP(nn.Module):
@@ -255,6 +257,7 @@ class LlamaModel(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
+ quant_config=quant_config,
  )
  self.layers = make_layers(
  config.num_hidden_layers,
@@ -295,16 +298,29 @@ class LlamaForCausalLM(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.torchao_config = global_server_args_dict["torchao_config"]
  self.model = LlamaModel(config, quant_config=quant_config)
- self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+ # Llama 3.2 1B Insturct set tie_word_embeddings to True
+ # Llama 3.1 8B Insturct set tie_word_embeddings to False
+ if self.config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size, config.hidden_size, quant_config=quant_config
+ )
  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+ self.stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".qkv_proj", ".q_proj", "q"),
+ (".qkv_proj", ".k_proj", "k"),
+ (".qkv_proj", ".v_proj", "v"),
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]

  @torch.no_grad()
  def forward(
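Note: the constructor change in the hunk above reuses the input embedding as the output projection whenever config.tie_word_embeddings is true, instead of allocating a separate ParallelLMHead and copying lm_head.weight at load time as 0.3.6 did. The toy snippet below is a plain-PyTorch illustration of what weight tying means; it is not sglang code:

# Plain-PyTorch illustration of weight tying (illustrative only, not sglang code).
import torch
from torch import nn

vocab_size, hidden_size = 8, 4
embed_tokens = nn.Embedding(vocab_size, hidden_size)

# Tied case: the output projection is the very same tensor as the input embedding,
# so no separate lm_head parameters are stored or loaded.
lm_head_weight = embed_tokens.weight

hidden_states = torch.randn(2, hidden_size)
logits = hidden_states @ lm_head_weight.t()  # shape: (2, vocab_size)
print(logits.shape)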
@@ -318,7 +334,7 @@ class LlamaForCausalLM(nn.Module):
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  if not get_embedding:
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )
  else:
  return self.pooler(hidden_states, forward_batch)
@@ -349,15 +365,7 @@ class LlamaForCausalLM(nn.Module):
  return params_mapping.get(name, name)

  def get_module_name_from_weight_name(self, name):
- stacked_params_mapping = [
- # (param_name, shard_name, shard_id, num_shard)
- ("qkv_proj", "q_proj", "q", 3),
- ("qkv_proj", "k_proj", "k", 3),
- ("qkv_proj", "v_proj", "v", 3),
- ("gate_up_proj", "gate_proj", 0, 2),
- ("gate_up_proj", "up_proj", 1, 2),
- ]
- for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+ for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
  if weight_name in name:
  return (
  name.replace(weight_name, param_name)[: -len(".weight")],
@@ -378,13 +386,8 @@ class LlamaForCausalLM(nn.Module):
  (".gate_up_proj", ".gate_proj", 0),
  (".gate_up_proj", ".up_proj", 1),
  ]
- params_dict = dict(self.named_parameters())

- load_tie_word_embeddings = (
- hasattr(self.config, "tie_word_embeddings")
- and self.config.tie_word_embeddings
- and "lm_head.weight" in params_dict
- )
+ params_dict = dict(self.named_parameters())

  for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -418,16 +421,78 @@ class LlamaForCausalLM(nn.Module):
  weight_loader = getattr(param, "weight_loader", default_weight_loader)
  weight_loader(param, loaded_weight)

- if load_tie_word_embeddings and name == "model.embed_tokens.weight":
- embed_tokens_weight = loaded_weight
-
- if load_tie_word_embeddings:
- # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
- param = self.lm_head.weight
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- weight_loader(param, embed_tokens_weight)
+ def get_weights_by_name(
+ self, name: str, truncate_size: int = 100, tp_size: int = 1
+ ) -> Optional[torch.Tensor]:
+ """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+ Only used for unit test with an unoptimized performance.
+ For optimized performance, please use torch.save and torch.load.
+ """
+ try:
+ if name == "lm_head.weight" and self.config.tie_word_embeddings:
+ logger.info(
+ "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+ )
+ return (
+ self.model.embed_tokens.weight.cpu()
+ .to(torch.float32)
+ .numpy()
+ .tolist()[:truncate_size]
+ )

- apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+ mapped_name = name
+ mapped_shard_id = None
+ for param_name, weight_name, shard_id in self.stacked_params_mapping:
+ if weight_name in name:
+ mapped_name = name.replace(weight_name, param_name)
+ mapped_shard_id = shard_id
+ break
+ params_dict = dict(self.named_parameters())
+ param = params_dict[mapped_name]
+ if mapped_shard_id is not None:
+ if mapped_shard_id in ["q", "k", "v"]:
+ num_heads = self.config.num_attention_heads // tp_size
+ num_kv_heads = self.config.num_key_value_heads // tp_size
+ head_dim = (
+ self.config.hidden_size // self.config.num_attention_heads
+ )
+ if mapped_shard_id == "q":
+ offset = 0
+ size = num_heads * head_dim
+ elif mapped_shard_id == "k":
+ offset = num_heads * head_dim
+ size = num_kv_heads * head_dim
+ elif mapped_shard_id == "v":
+ offset = (num_heads + num_kv_heads) * head_dim
+ size = num_kv_heads * head_dim
+ weight = param.data.narrow(0, offset, size)
+ elif mapped_shard_id in [0, 1]:
+ intermediate_size = self.config.intermediate_size
+ slice_size = intermediate_size // tp_size
+ if mapped_shard_id == 0: # gate_proj
+ offset = 0
+ size = slice_size
+ elif mapped_shard_id == 1: # up_proj
+ offset = slice_size
+ size = slice_size
+
+ weight = param.data.narrow(0, offset, size)
+ else:
+ weight = param.data
+ else:
+ weight = param.data
+ if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+ gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+ torch.distributed.all_gather(gathered_weights, weight)
+ weight = torch.cat(gathered_weights, dim=1)
+ return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+ except Exception:
+ logger.error(
+ f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
+ )
+ return None


  class Phi3ForCausalLM(LlamaForCausalLM):
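Note: the new get_weights_by_name helper above recovers an individual q/k/v or gate/up slice from the fused qkv_proj / gate_up_proj parameters by plain offset arithmetic along dim 0. The following standalone PyTorch check of that arithmetic uses made-up head counts and is not part of the package:

# Standalone check of the q/k/v offsets used in get_weights_by_name (illustrative sizes).
import torch

num_heads, num_kv_heads, head_dim = 4, 2, 8
hidden_size = num_heads * head_dim

q = torch.randn(num_heads * head_dim, hidden_size)
k = torch.randn(num_kv_heads * head_dim, hidden_size)
v = torch.randn(num_kv_heads * head_dim, hidden_size)
qkv = torch.cat([q, k, v], dim=0)  # fused qkv_proj.weight layout: [q; k; v]

# Same narrow(0, offset, size) calls as in get_weights_by_name.
q_slice = qkv.narrow(0, 0, num_heads * head_dim)
k_slice = qkv.narrow(0, num_heads * head_dim, num_kv_heads * head_dim)
v_slice = qkv.narrow(0, (num_heads + num_kv_heads) * head_dim, num_kv_heads * head_dim)

assert torch.equal(q_slice, q) and torch.equal(k_slice, k) and torch.equal(v_slice, v)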
sglang/srt/models/llama_classification.py CHANGED
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/llama_embedding.py CHANGED
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.model_executor.model_runner import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaModel


@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
  self,
  config: LlamaConfig,
  quant_config=None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llama_reward.py CHANGED
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
  from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel


@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config=None,
  ) -> None:
- super().__init__(config, quant_config, cache_config)
+ super().__init__(config, quant_config)
  self.weights = self.Weights(config.hidden_size, self.num_labels)

  @torch.no_grad()