sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt2.py ADDED
@@ -0,0 +1,287 @@
+ # coding=utf-8
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
+ # Copyright 2023 The vLLM team.
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Inference-only GPT-2 model compatible with HuggingFace weights."""
+ from typing import Iterable, List, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import GPT2Config
+ from vllm.config import CacheConfig
+ from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.activation import get_act_fn
+ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ #from sglang.srt.layers.activation import get_act_fn
+ from sglang.srt.layers.linear import (
+     ColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+ class GPT2Attention(nn.Module):
+
+     def __init__(
+         self,
+         layer_id: int,
+         config: GPT2Config,
+         cache_config = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         total_num_heads = config.num_attention_heads
+         tensor_model_parallel_world_size = (
+             get_tensor_model_parallel_world_size())
+         assert total_num_heads % tensor_model_parallel_world_size == 0
+         self.num_heads = total_num_heads // tensor_model_parallel_world_size
+         self.head_dim = self.hidden_size // total_num_heads
+         self.scale = self.head_dim**-0.5
+
+         self.c_attn = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             total_num_heads,
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.c_attn",
+         )
+         self.c_proj = RowParallelLinear(
+             self.hidden_size,
+             self.hidden_size,
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.c_proj",
+         )
+         self.attn = RadixAttention(self.num_heads,
+                                    self.head_dim,
+                                    scaling=self.scale,
+                                    num_kv_heads=total_num_heads,
+                                    layer_id=layer_id)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.c_attn(hidden_states)
+         q, k, v = qkv.chunk(chunks=3, dim=-1)
+         attn_output = self.attn(q, k, v, forward_batch)
+         attn_output, _ = self.c_proj(attn_output)
+         return attn_output
+
+
+ class GPT2MLP(nn.Module):
+
+     def __init__(
+         self,
+         intermediate_size: int,
+         config: GPT2Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         hidden_size = config.hidden_size
+         self.c_fc = ColumnParallelLinear(
+             hidden_size,
+             intermediate_size,
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.c_fc",
+         )
+         self.c_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=True,
+             quant_config=quant_config,
+             prefix=f"{prefix}.c_proj",
+         )
+         self.act = get_act_fn(config.activation_function, quant_config,
+                               intermediate_size)
+
+     def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+         hidden_states, _ = self.c_fc(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states, _ = self.c_proj(hidden_states)
+         return hidden_states
+
+
+ class GPT2Block(nn.Module):
+
+     def __init__(
+         self,
+         layer_id: int,
+         config: GPT2Config,
+         cache_config = None,
+
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         hidden_size = config.hidden_size
+         inner_dim = (config.n_inner if config.n_inner is not None else 4 *
+                      hidden_size)
+
+         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.attn = GPT2Attention(layer_id,
+                                   config,
+                                   cache_config,
+                                   quant_config,
+                                   prefix=f"{prefix}.attn")
+         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.mlp = GPT2MLP(inner_dim,
+                            config,
+                            quant_config,
+                            prefix=f"{prefix}.mlp")
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         residual = hidden_states
+         hidden_states = self.ln_1(hidden_states)
+         attn_output = self.attn(
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+         # residual connection
+         hidden_states = attn_output + residual
+
+         residual = hidden_states
+         hidden_states = self.ln_2(hidden_states)
+         feed_forward_hidden_states = self.mlp(hidden_states)
+         # residual connection
+         hidden_states = residual + feed_forward_hidden_states
+         return hidden_states
+
+
+
+ class GPT2Model(nn.Module):
+
+     def __init__(
+         self,
+         config: GPT2Config,
+         cache_config = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.config = config
+         assert not config.add_cross_attention
+         assert not config.scale_attn_by_inverse_layer_idx
+         assert not config.reorder_and_upcast_attn
+         self.embed_dim = config.hidden_size
+         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+         self.h = nn.ModuleList(
+             [
+                 GPT2Block(i, config, cache_config, quant_config)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+
+         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         position_ids: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         inputs_embeds = self.wte(input_ids)
+         position_embeds = self.wpe(position_ids)
+         hidden_states = inputs_embeds + position_embeds
+
+         for i in range(len(self.h)):
+             layer = self.h[i]
+             hidden_states = layer(hidden_states, forward_batch)
+
+         hidden_states = self.ln_f(hidden_states)
+         return hidden_states
+
+
+ class GPT2LMHeadModel(nn.Module):
+
+     def __init__(
+         self,
+         config: GPT2Config,
+         cache_config = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.transformer = GPT2Model(config,
+                                      cache_config,
+                                      quant_config,
+                                      prefix="transformer")
+         self.lm_head = self.transformer.wte
+
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         hidden_states = self.transformer(input_ids, positions, forward_batch)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, forward_batch
+         )
+
+
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         params_dict = dict(self.named_parameters(remove_duplicate=False))
+         for name, loaded_weight in weights:
+             if "lm_head.weight" in name:
+                 # GPT-2 ties the weights of the embedding layer and the final
+                 # linear layer.
+                 continue
+             if ".attn.bias" in name or ".attn.masked_bias" in name:
+                 # Skip attention mask.
+                 # NOTE: "c_attn.bias" should not be skipped.
+                 continue
+             if not name.startswith("transformer."):
+                 name = "transformer." + name
+
+             param = params_dict[name]
+             # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+             # Because of this, we need to transpose the weights.
+             # Note(zhuohan): the logic below might break quantized models.
+             for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                 if conv1d_weight_name not in name:
+                     continue
+                 if not name.endswith(".weight"):
+                     continue
+                 loaded_weight = loaded_weight.t()
+             weight_loader = getattr(param, "weight_loader",
+                                     default_weight_loader)
+             weight_loader(param, loaded_weight)
+
+ EntryClass = GPT2LMHeadModel
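
Note: the file above adds GPT-2 as a newly supported architecture in 0.3.5. The snippet below is a hypothetical smoke test, not part of the diff; it assumes a server started locally with `python -m sglang.launch_server --model-path gpt2 --port 30000` and uses the OpenAI-compatible /v1/completions route that sglang exposes. Prompt text and token count are arbitrary.

# Hypothetical smoke test for the new GPT-2 support (not part of the diff).
# Assumes: python -m sglang.launch_server --model-path gpt2 --port 30000
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/completions",
    json={"model": "gpt2", "prompt": "The capital of France is", "max_tokens": 8},
)
# Print the generated continuation returned by the server
print(resp.json()["choices"][0]["text"])
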
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
  from transformers import GPTBigCodeConfig
  from vllm.config import LoRAConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import get_act_fn
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/grok.py CHANGED
@@ -28,10 +28,6 @@ from vllm.distributed import (
      get_tensor_model_parallel_world_size,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.loader import DefaultModelLoader
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

@@ -45,6 +41,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/internlm2.py CHANGED
@@ -23,10 +23,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/llama.py CHANGED
@@ -24,10 +24,6 @@ from torch import nn
  from transformers import LlamaConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -38,9 +34,14 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -303,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
          self.model = LlamaModel(config, quant_config=quant_config)
          self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

      @torch.no_grad()
      def forward(
@@ -311,11 +313,15 @@
          positions: torch.Tensor,
          forward_batch: ForwardBatch,
          input_embeds: torch.Tensor = None,
+         get_embedding: bool = False,
      ) -> LogitsProcessorOutput:
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-         return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
-         )
+         if not get_embedding:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head.weight, forward_batch
+             )
+         else:
+             return self.pooler(hidden_states, forward_batch)

      def get_hidden_dim(self, module_name):
          # return input_dim, output_dim
@@ -409,11 +415,13 @@
          if (
              hasattr(self.config, "tie_word_embeddings")
              and self.config.tie_word_embeddings
+             and "lm_head.weight" in params_dict
          ):
              # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
              param = self.lm_head.weight
              weight_loader = getattr(param, "weight_loader", default_weight_loader)
              weight_loader(param, self.model.embed_tokens.weight)
+
          apply_torchao_config_(self, params_dict, set(["proj.weight"]))


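
Note: the llama.py hunks above route LlamaForCausalLM.forward through the new Pooler when get_embedding=True instead of returning logits. The snippet below is an illustrative sketch only, not sglang's Pooler implementation, of what PoolingType.LAST with normalize=True implies: take each packed sequence's final hidden state and L2-normalize it.

# Illustrative only: last-token pooling with L2 normalization for sequences
# packed along the token dimension (assumed shapes, not sglang internals).
import torch

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: [total_tokens, hidden_size]; seq_lens: [batch_size]
    last_idx = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's last token
    embeddings = hidden_states[last_idx]
    return torch.nn.functional.normalize(embeddings, p=2.0, dim=-1)

# Two packed sequences of lengths 3 and 2, hidden size 4
h = torch.randn(5, 4)
print(last_token_pool(h, torch.tensor([3, 2])).shape)  # torch.Size([2, 4])
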
sglang/srt/models/llama_embedding.py CHANGED
@@ -36,9 +36,7 @@ class LlamaEmbeddingModel(nn.Module):
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          return self.pooler(hidden_states, forward_batch)

-     def load_weights(
-         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-     ):
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
@@ -49,7 +47,7 @@
          ]
          params_dict = dict(self.model.named_parameters())

-         def load_weights_per_param(name, loaded_weight):
+         for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  return
              if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -78,12 +76,6 @@
              weight_loader = getattr(param, "weight_loader", default_weight_loader)
              weight_loader(param, loaded_weight)

-         if name is None or loaded_weight is None:
-             for name, loaded_weight in weights:
-                 load_weights_per_param(name, loaded_weight)
-         else:
-             load_weights_per_param(name, loaded_weight)
-

  class MistralModel(LlamaEmbeddingModel):
      pass
sglang/srt/models/llama_reward.py CHANGED
@@ -52,7 +52,12 @@ class LlamaForSequenceClassification(nn.Module):
          positions: torch.Tensor,
          forward_batch: ForwardBatch,
          input_embeds: torch.Tensor = None,
+         get_embedding: bool = True,
      ) -> EmbeddingPoolerOutput:
+         assert (
+             get_embedding
+         ), "LlamaForSequenceClassification is only used for embedding"
+
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          scores = self.score(hidden_states)

sglang/srt/models/minicpm.py CHANGED
@@ -22,10 +22,6 @@ import torch
  from torch import nn
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -38,6 +34,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/minicpm3.py CHANGED
@@ -29,10 +29,6 @@ from vllm.model_executor.layers.linear import (
      RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -40,6 +36,10 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.utils import is_flashinfer_available
sglang/srt/models/mixtral.py CHANGED
@@ -24,11 +24,6 @@ from transformers import MixtralConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.fused_moe import FusedMoE
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     DEFAULT_VOCAB_PADDING_SIZE,
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.layernorm import RMSNorm
@@ -41,6 +36,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.torchao_utils import apply_torchao_config_
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -369,6 +368,9 @@ class MixtralForCausalLM(nn.Module):
                      # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
                          continue
+                     # Skip loading kv_scale from ckpts towards new design.
+                     if name.endswith(".kv_scale") and name not in params_dict:
+                         continue
                      if name is None:
                          continue

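
Note: the mixtral.py hunk above skips checkpoint entries ending in ".kv_scale" that have no matching model parameter, mirroring the existing ".bias" skip for GPTQ checkpoints. Below is a minimal, hypothetical helper (not the actual load_weights) sketching that filtering pattern.

# Hypothetical helper illustrating the skip pattern: drop checkpoint tensors
# (e.g. ".kv_scale", GPTQ ".bias") that the model has no parameter for,
# instead of failing on the params_dict lookup.
from typing import Dict, Iterable, Iterator, Tuple

import torch

def loadable_weights(
    weights: Iterable[Tuple[str, torch.Tensor]],
    params_dict: Dict[str, torch.nn.Parameter],
) -> Iterator[Tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        if name.endswith((".bias", ".kv_scale")) and name not in params_dict:
            continue  # extra checkpoint entry with no matching model parameter
        yield name, tensor
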
sglang/srt/models/mixtral_quant.py CHANGED
@@ -29,10 +29,6 @@ from vllm.distributed import (
      tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.layernorm import RMSNorm
@@ -44,6 +40,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/mllama.py CHANGED
@@ -15,11 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
      _prepare_aspect_ratio_attention_mask,
  )
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     DEFAULT_VOCAB_PADDING_SIZE,
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import get_act_fn
@@ -32,6 +27,11 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     DEFAULT_VOCAB_PADDING_SIZE,
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.managers.schedule_batch import ImageInputs
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
sglang/srt/models/olmo.py CHANGED
@@ -23,10 +23,6 @@ from torch import nn
  from transformers import OlmoConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -38,6 +34,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/olmoe.py CHANGED
@@ -35,10 +35,6 @@ from vllm.model_executor.layers.linear import (
      RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
  from vllm.utils import print_warning_once

@@ -47,6 +43,10 @@ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/qwen.py CHANGED
@@ -22,10 +22,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -38,6 +34,10 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/qwen2.py CHANGED
@@ -22,10 +22,6 @@ import torch
  from torch import nn
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.layers.vocab_parallel_embedding import (
-     ParallelLMHead,
-     VocabParallelEmbedding,
- )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch

  Qwen2Config = None