sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/model_config.py CHANGED
@@ -1,5 +1,7 @@
 from typing import Optional
 
+from transformers import PretrainedConfig
+
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
 
 
@@ -16,9 +18,13 @@ class ModelConfig:
         self.trust_remote_code = trust_remote_code
         self.revision = revision
         self.model_overide_args = model_overide_args
-        self.hf_config = get_config(self.path, trust_remote_code, revision,
-                                    model_overide_args=model_overide_args)
-
+        self.hf_config = get_config(
+            self.path,
+            trust_remote_code,
+            revision,
+            model_overide_args=model_overide_args,
+        )
+        self.hf_text_config = get_hf_text_config(self.hf_config)
         if context_length is not None:
             self.context_len = context_length
         else:
@@ -43,4 +49,77 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
+        self.vocab_size = self.hf_config.vocab_size
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
+        new_decoder_arch_falcon = (
+            self.hf_config.model_type in falcon_model_types
+            and getattr(self.hf_config, "new_decoder_architecture", False)
+        )
+        if not new_decoder_arch_falcon and getattr(
+            self.hf_text_config, "multi_query", False
+        ):
+            # Multi-query attention, only one KV head.
+            # Currently, tensor parallelism is not supported in this case.
+            return 1
+
+        # For DBRX and MPT
+        if self.hf_config.model_type in ["mpt"]:
+            if "kv_n_heads" in self.hf_config.attn_config:
+                return self.hf_config.attn_config["kv_n_heads"]
+            return self.hf_config.num_attention_heads
+        if self.hf_config.model_type in ["dbrx"]:
+            return getattr(
+                self.hf_config.attn_config,
+                "kv_n_heads",
+                self.hf_config.num_attention_heads,
+            )
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_text_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_text_config.num_attention_heads
+
+    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
+    def get_num_kv_heads(self, tensor_parallel_size) -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1, total_num_kv_heads // tensor_parallel_size)
+
+
+def get_hf_text_config(config: PretrainedConfig):
+    """Get the "sub" config relevant to llm for multi modal models.
+    No op for pure text models.
+    """
+    if hasattr(config, "text_config"):
+        # The code operates under the assumption that text_config should have
+        # `num_attention_heads` (among others). Assert here to fail early
+        # if transformers config doesn't align with this assumption.
+        assert hasattr(config.text_config, "num_attention_heads")
+        return config.text_config
+    else:
+        return config
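The per-GPU KV head count added above is plain integer arithmetic. Below is a minimal standalone sketch that mirrors get_num_kv_heads() so the behavior is easy to check in isolation; the head counts in the example (8 grouped-query KV heads, 1 multi-query KV head) are illustrative assumptions, not values taken from any particular checkpoint.

# Minimal sketch mirroring ModelConfig.get_num_kv_heads() above (illustrative only).
def kv_heads_per_gpu(total_num_kv_heads: int, tensor_parallel_size: int) -> int:
    # Split KV heads across tensor-parallel ranks; when there are fewer KV heads
    # than ranks, each rank keeps a replicated copy so it has at least one.
    return max(1, total_num_kv_heads // tensor_parallel_size)

# A grouped-query model with 8 KV heads on 4 GPUs keeps 2 KV heads per GPU.
assert kv_heads_per_gpu(8, 4) == 2
# A multi-query model (1 KV head) on 8 GPUs replicates that single head on every rank.
assert kv_heads_per_gpu(1, 8) == 1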
sglang/srt/models/chatglm.py ADDED
@@ -0,0 +1,399 @@
+# coding=utf-8
+# Adapted from
+# https://github.com/THUDM/ChatGLM2-6B
+"""Inference-only ChatGLM model compatible with THUDM weights."""
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import LayerNorm
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
+from vllm.transformers_utils.configs import ChatGLMConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+LoraConfig = None
+
+
+class GLMAttention(nn.Module):
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.multi_query_attention = config.multi_query_attention
+        self.total_num_kv_heads = (
+            config.multi_query_group_num
+            if config.multi_query_attention
+            else config.num_attention_heads
+        )
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = config.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.add_bias_linear or config.add_qkv_bias,
+            quant_config=quant_config,
+        )
+        self.dense = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
+        rope_ratio = getattr(config, "rope_ratio", 1.0)
+        max_positions = getattr(config, "seq_length", 8192)
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim // 2,
+            max_position=max_positions,
+            base=10000 * rope_ratio,
+            is_neox_style=False,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.query_key_value(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            input_metadata,
+        )
+        attn_output, _ = self.dense(context_layer)
+        return attn_output
+
+
+class GLMMLP(nn.Module):
+    """MLP.
+
+    MLP will take the input with h hidden state, project it to 4*h
+    hidden dimension, perform nonlinear transformation, and project the
+    state back into h hidden dimension.
+    """
+
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.add_bias = config.add_bias_linear
+
+        # Project to 4h.
+        self.dense_h_to_4h = MergedColumnParallelLinear(
+            config.hidden_size,
+            [config.ffn_hidden_size] * 2,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+        self.activation_func = SiluAndMul()
+
+        # Project back to h.
+        self.dense_4h_to_h = RowParallelLinear(
+            config.ffn_hidden_size,
+            config.hidden_size,
+            bias=config.add_bias_linear,
+            quant_config=quant_config,
+        )
+
+    def forward(self, hidden_states):
+        # [s, b, 4hp]
+        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
+        intermediate_parallel = self.activation_func(intermediate_parallel)
+        # [s, b, h]
+        output, _ = self.dense_4h_to_h(intermediate_parallel)
+        return output
+
+
+class GLMBlock(nn.Module):
+    """A single transformer layer.
+
+    Transformer layer takes input with size [s, b, h] and returns an
+    output of the same size.
+    """
+
+    def __init__(
+        self,
+        config,
+        layer_id: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.apply_residual_connection_post_layernorm = (
+            config.apply_residual_connection_post_layernorm
+        )
+
+        self.fp32_residual_connection = config.fp32_residual_connection
+
+        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+        # Layernorm on the input data.
+        self.input_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon
+        )
+
+        # Self attention.
+        self.self_attention = GLMAttention(config, layer_id, cache_config, quant_config)
+        self.hidden_dropout = config.hidden_dropout
+
+        # Layernorm on the attention output
+        self.post_attention_layernorm = layer_norm_func(
+            config.hidden_size, eps=config.layernorm_epsilon
+        )
+
+        # MLP
+        self.mlp = GLMMLP(config, quant_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # hidden_states: [num_tokens, h]
+        # Layer norm at the beginning of the transformer layer.
+        layernorm_output = self.input_layernorm(hidden_states)
+        # Self attention.
+        attention_output = self.self_attention(
+            hidden_states=layernorm_output,
+            position_ids=position_ids,
+            input_metadata=input_metadata,
+        )
+
+        # Residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = hidden_states
+
+        layernorm_input = residual + attention_output
+
+        # Layer norm post the self attention.
+        layernorm_output = self.post_attention_layernorm(layernorm_input)
+
+        # Second residual connection.
+        if self.apply_residual_connection_post_layernorm:
+            residual = layernorm_output
+        else:
+            residual = layernorm_input
+
+        output = self.mlp(layernorm_output) + residual
+
+        return output
+
+
+class GLMTransformer(nn.Module):
+    """Transformer class."""
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.post_layer_norm = config.post_layer_norm
+
+        # Number of layers.
+        self.num_layers = config.num_layers
+
+        # Transformer layers.
+        self.layers = nn.ModuleList(
+            [
+                GLMBlock(config, i, cache_config, quant_config)
+                for i in range(self.num_layers)
+            ]
+        )
+
+        if self.post_layer_norm:
+            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
+            # Final layer norm before output.
+            self.final_layernorm = layer_norm_func(
+                config.hidden_size, eps=config.layernorm_epsilon
+            )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                input_metadata=input_metadata,
+            )
+        # Final layer norm.
+        if self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class ChatGLMModel(nn.Module):
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+
+        self.embedding = VocabParallelEmbedding(
+            config.padded_vocab_size, config.hidden_size
+        )
+
+        self.num_layers = config.num_layers
+        self.multi_query_group_num = config.multi_query_group_num
+        self.kv_channels = config.kv_channels
+        self.encoder = GLMTransformer(config, cache_config, quant_config)
+
+        self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embedding(input_ids)
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states=inputs_embeds,
+            position_ids=position_ids,
+            input_metadata=input_metadata,
+        )
+        return hidden_states
+
+
+class ChatGLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: ChatGLMConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoraConfig] = None,
+    ):
+        super().__init__()
+        self.config: ChatGLMConfig = config
+        self.quant_config = quant_config
+        self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
+        self.transformer = ChatGLMModel(config, cache_config, quant_config)
+        self.lm_head = self.transformer.output_layer
+        self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_pos_emb.inv_freq" in name:
+                continue
+            if "word_embeddings" in name:
+                name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+
+EntryClass = ChatGLMForCausalLM
+# compat: glm model.config class == ChatGLMModel
+EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)]
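For orientation, the attention sizes computed in GLMAttention above follow directly from the ChatGLM config. A rough worked example, assuming chatglm3-6b-like values (hidden_size=4096, 32 attention heads, multi_query_group_num=2) and no tensor parallelism; the config numbers are assumptions for illustration, not values read from this diff.

# Assumed ChatGLM-style config values (illustrative only).
hidden_size, total_num_heads, multi_query_group_num, tp_size = 4096, 32, 2, 1

head_dim = hidden_size // total_num_heads                # 128
num_heads = total_num_heads // tp_size                   # 32 query heads per rank
num_kv_heads = max(1, multi_query_group_num // tp_size)  # 2 KV heads per rank

q_size = num_heads * head_dim      # 4096 columns for Q
kv_size = num_kv_heads * head_dim  # 256 columns each for K and V
# query_key_value emits q_size + 2 * kv_size columns, split apart in forward():
assert q_size + 2 * kv_size == 4608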
sglang/srt/models/commandr.py CHANGED
@@ -23,7 +23,7 @@
 
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import Optional, Tuple, Iterable
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.utils.checkpoint
@@ -44,8 +44,8 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
sglang/srt/models/dbrx.py CHANGED
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.utils import set_weight_attrs
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
sglang/srt/models/gemma.py CHANGED
@@ -6,7 +6,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig, CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -310,6 +310,10 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
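The Gemma change above skips the checkpoint's lm_head.weight because the output projection is tied to the input embedding, so there is no separate parameter to fill. A toy sketch of that tying pattern is below; the class and attribute names are made up for illustration and are not Gemma's actual definitions.

from torch import nn

class TiedToyLM(nn.Module):
    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the output projection to the input embedding table.
        self.lm_head.weight = self.embed_tokens.weight

model = TiedToyLM()
params = dict(model.named_parameters())  # deduplicates tied parameters by default
# Only the embedding weight remains a distinct parameter, so a checkpoint's
# "lm_head.weight" entry can simply be skipped during loading.
assert "lm_head.weight" not in params
assert "embed_tokens.weight" in params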