sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/olmo2.py ADDED
@@ -0,0 +1,391 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/olmo2.py
+ """Inference-only OLMo2 model compatible with HuggingFace weights."""
+ from functools import partial
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+ )
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.utils import make_layers
+
+
+ class Olmo2Attention(nn.Module):
+     """
+     This is the attention block where the output is computed as
+     ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+     (plus another skip connection).
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = config.num_attention_heads
+
+         assert self.hidden_size % self.total_num_heads == 0
+         assert self.total_num_heads % tp_size == 0
+
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = self.config.num_key_value_heads
+
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+         self.head_dim = self.hidden_size // self.total_num_heads
+         self.max_position_embeddings = config.max_position_embeddings
+         self.rope_theta = config.rope_theta
+
+         # Attention input projection. Projects x -> (q, k, v)
+         self.qkv_proj = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             bias=config.attention_bias,
+         )
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.k_norm = RMSNorm(
+             self.total_num_kv_heads * self.head_dim,
+             eps=self.config.rms_norm_eps,
+         )
+         self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
+         # Rotary embeddings.
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             base=self.rope_theta,
+         )
+         self.scaling = self.head_dim**-0.5
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+         # Attention output projection.
+         self.o_proj = RowParallelLinear(
+             self.head_dim * self.total_num_heads,
+             self.hidden_size,
+             bias=config.attention_bias,
+         )
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         if self.tp_size > 1:
+             q = tensor_model_parallel_all_gather(q.contiguous())
+             k = tensor_model_parallel_all_gather(k.contiguous())
+         q = self.q_norm.forward_native(q)
+         k = self.k_norm.forward_native(k)
+         if self.tp_size > 1:
+             splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
+             q = splitter(q)[self.tp_rank]
+             k = splitter(k)[self.tp_rank]
+         return q, k
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.chunk(chunks=3, dim=-1)
+         q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Olmo2MLP(nn.Module):
+     """
+     This is the MLP block where the output is computed as
+     ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))``
+     (plus another skip connection).
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+
+         # Feed-forward input projection.
+         self.gate_up_proj = MergedColumnParallelLinear(
+             self.hidden_size,
+             [self.intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         # Activation function.
+         self.act_fn = SiluAndMul()
+
+         # Feed-forward output projection.
+         self.down_proj = RowParallelLinear(
+             self.intermediate_size,
+             self.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ) -> torch.Tensor:
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class Olmo2DecoderLayer(nn.Module):
+     """
+     This is a typical transformer block where the output is
+     computed as ``MLP(LN(x + Attention(LN(x))))``
+     (plus another skip connection).
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         # Attention block.
+         self.self_attn = Olmo2Attention(config, layer_id, quant_config)
+
+         # MLP block.
+         self.mlp = Olmo2MLP(config, quant_config)
+
+         # RMSNorm
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+         self.post_feedforward_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         # Attention block.
+         residual = hidden_states
+         hidden_states = self.self_attn(positions, hidden_states, forward_batch)
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = hidden_states + residual
+
+         # MLP block.
+         residual = hidden_states
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = self.post_feedforward_layernorm(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class Olmo2Model(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size, config.hidden_size
+         )
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Olmo2DecoderLayer(
+                 layer_id=idx,
+                 config=config,
+                 quant_config=quant_config,
+             ),
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         """
+         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+         """
+         # Get embeddings of input.
+         # shape: (batch_size, seq_len, d_model)
+
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+
+         # Apply blocks one-by-one.
+         for layer_id, decoder_layer in enumerate(self.layers):
+             # shape: (batch_size, seq_len, d_model)
+             hidden_states = decoder_layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+             )
+
+         # Apply final layer norm.
+         # shape: (batch_size, seq_len or 1, d_model)
+         hidden_states = self.norm(hidden_states)
+         return hidden_states
+
+
+ class Olmo2ForCausalLM(nn.Module):
+     """
+     Extremely barebones HF model wrapper.
+     """
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.model = Olmo2Model(config, quant_config)
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.unpadded_vocab_size = config.vocab_size
+             self.lm_head = ParallelLMHead(
+                 self.unpadded_vocab_size,
+                 config.hidden_size,
+                 org_num_embeddings=config.vocab_size,
+                 quant_config=quant_config,
+             )
+         self.logits_processor = LogitsProcessor(config)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(
+             input_ids=input_ids,
+             positions=positions,
+             forward_batch=forward_batch,
+             input_embeds=input_embeds,
+         )
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters(remove_duplicate=False))
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             # With tie_word_embeddings, we can skip lm_head.weight
+             # The weight might appear unnecessarily in the files if the model is
+             # processed with quantization, LoRA, fine-tuning, etc.
+             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = Olmo2ForCausalLM
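The new Olmo2DecoderLayer above normalizes each sub-layer output before adding it back to the residual, which differs from the pre-norm ordering used by the Llama-style models in this package. The following is a minimal, self-contained PyTorch sketch of that ordering only; TinyOlmo2StyleBlock is an illustrative name, nn.LayerNorm stands in for RMSNorm, and nn.MultiheadAttention stands in for RadixAttention, so this is not the sglang implementation.

import torch
from torch import nn

class TinyOlmo2StyleBlock(nn.Module):
    """Illustrative only: mirrors the norm-after-sublayer ordering shown above."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model)
        )
        self.post_attention_layernorm = nn.LayerNorm(d_model)
        self.post_feedforward_layernorm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Attention block: normalize the sub-layer output, then add the residual.
        residual = x
        h, _ = self.self_attn(x, x, x, need_weights=False)
        x = residual + self.post_attention_layernorm(h)
        # MLP block: same ordering.
        residual = x
        x = residual + self.post_feedforward_layernorm(self.mlp(x))
        return x

x = torch.randn(2, 8, 64)
print(TinyOlmo2StyleBlock(64, 256)(x).shape)  # torch.Size([2, 8, 64])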
sglang/srt/models/olmoe.py CHANGED
@@ -34,8 +34,6 @@ from vllm.model_executor.layers.linear import (
      RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.utils import print_warning_once
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,7 +46,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import make_layers
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import make_layers, print_warning_once
 
 
  class OlmoeMoE(nn.Module):
@@ -300,7 +299,6 @@ class OlmoeForCausalLM(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -321,7 +319,7 @@ class OlmoeForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )
 
      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
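The last hunk above shows a change that recurs throughout this release (and again in the phi3_small, qwen, qwen2, and qwen2_moe diffs below): the logits processor now receives the lm_head module rather than its raw .weight tensor. A plausible reading is that the callee can then decide how logits are computed, which matters once the head may be a tied embedding or a quantized linear. The sketch below is hypothetical and generic, not sglang's actual LogitsProcessor; the quant_method attribute is only an assumed marker for a quantized head.

import torch
import torch.nn.functional as F
from torch import nn

def logits_from_weight(hidden: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Old-style interface: the caller must hand over a dense weight tensor.
    return F.linear(hidden, weight)

def logits_from_head(hidden: torch.Tensor, head: nn.Module) -> torch.Tensor:
    # New-style interface (sketch): the head decides how logits are computed, so
    # tied-embedding or quantized heads need no special casing at the call site.
    if hasattr(head, "quant_method"):  # hypothetical quantized-head hook
        return head.quant_method.apply(head, hidden)
    return F.linear(hidden, head.weight)

hidden = torch.randn(2, 16)
head = nn.Linear(16, 100, bias=False)
assert torch.allclose(logits_from_weight(hidden, head.weight), logits_from_head(hidden, head))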
sglang/srt/models/phi3_small.py CHANGED
@@ -7,8 +7,6 @@ from transformers import Phi3Config
  from transformers.configuration_utils import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.models.utils import make_layers
 
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -27,6 +25,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import make_layers
 
 
  @torch.jit.script
@@ -235,7 +235,6 @@ class Phi3SmallDecoderLayer(nn.Module):
          self,
          config: PretrainedConfig,
          layer_id: int,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          super().__init__()
@@ -286,7 +285,6 @@ class Phi3SmallModel(nn.Module):
          super().__init__()
 
          self.config = config
-         cache_config = None
          self.embed_tokens = VocabParallelEmbedding(
              config.vocab_size, config.hidden_size
          )
@@ -294,7 +292,7 @@ class Phi3SmallModel(nn.Module):
          self.start_layer, self.end_layer, self.layers = make_layers(
              config.num_hidden_layers,
              lambda prefix: Phi3SmallDecoderLayer(
-                 config, int(prefix.split(".")[-1]), cache_config, quant_config
+                 config, int(prefix.split(".")[-1]), quant_config
              ),
              prefix=f"{prefix}.layers",
          )
@@ -339,7 +337,6 @@ class Phi3SmallForCausalLM(nn.Module):
          self,
          config: Phi3Config,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
      ):
 
          super().__init__()
@@ -397,10 +394,13 @@ class Phi3SmallForCausalLM(nn.Module):
 
      def compute_logits(
          self,
+         input_ids: torch.LongTensor,
          hidden_states: torch.Tensor,
          sampling_metadata,
      ) -> Optional[torch.Tensor]:
-         logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+         logits = self.logits_processor(
+             input_ids, self.lm_head, hidden_states, sampling_metadata
+         )
          if self.dummy_token_indices is not None and logits is not None:
              logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
          return logits
@@ -422,7 +422,7 @@ class Phi3SmallForCausalLM(nn.Module):
 
          if not get_embedding:
              return self.logits_processor(
-                 input_ids, hidden_states, self.lm_head.weight, forward_batch
+                 input_ids, hidden_states, self.lm_head, forward_batch
              )
 
          else:
sglang/srt/models/qwen.py CHANGED
@@ -22,7 +22,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
  class QWenMLP(nn.Module):
@@ -242,7 +242,6 @@ class QWenLMHeadModel(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
      ):
          super().__init__()
          self.config = config
@@ -260,7 +259,7 @@ class QWenLMHeadModel(nn.Module):
      ):
          hidden_states = self.transformer(input_ids, positions, forward_batch)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )
 
      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen2.py CHANGED
@@ -22,7 +22,6 @@ import torch
  from torch import nn
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers
 
  Qwen2Config = None
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
          self.embed_tokens = VocabParallelEmbedding(
              config.vocab_size,
              config.hidden_size,
+             quant_config=quant_config,
          )
          self.layers = make_layers(
              config.num_hidden_layers,
@@ -270,13 +271,17 @@ class Qwen2ForCausalLM(nn.Module):
          self,
          config: Qwen2Config,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
      ) -> None:
          super().__init__()
          self.config = config
          self.quant_config = quant_config
          self.model = Qwen2Model(config, quant_config=quant_config)
-         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size, config.hidden_size, quant_config=quant_config
+             )
          self.logits_processor = LogitsProcessor(config)
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
@@ -292,7 +297,7 @@ class Qwen2ForCausalLM(nn.Module):
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          if not get_embedding:
              return self.logits_processor(
-                 input_ids, hidden_states, self.lm_head.weight, forward_batch
+                 input_ids, hidden_states, self.lm_head, forward_batch
              )
          else:
              return self.pooler(hidden_states, forward_batch)
@@ -306,6 +311,7 @@ class Qwen2ForCausalLM(nn.Module):
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
+
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +341,6 @@ class Qwen2ForCausalLM(nn.Module):
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
-                 if (
-                     self.config.tie_word_embeddings
-                     and name == "model.embed_tokens.weight"
-                 ):
-                     weight_loader(params_dict["lm_head.weight"], loaded_weight)
 
 
  EntryClass = Qwen2ForCausalLM
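The Qwen2ForCausalLM hunks above replace the old copy-at-load-time handling of tie_word_embeddings (removed from load_weights) with direct reuse of the embedding module as the head. The sketch below contrasts the two approaches in plain PyTorch with made-up variable names; the point is only that the tied head shares storage with embed_tokens instead of duplicating it.

import torch
from torch import nn

vocab, hidden = 100, 16
embed_tokens = nn.Embedding(vocab, hidden)

# Old pattern (sketch): a separate head whose weight is copied from the
# embedding during weight loading, i.e. two copies of the same matrix.
lm_head_copy = nn.Linear(hidden, vocab, bias=False)
lm_head_copy.weight.data.copy_(embed_tokens.weight.data)

# New pattern (sketch): the head *is* the embedding module, so the parameter
# is shared and logits are computed against the same storage.
lm_head_tied = embed_tokens

h = torch.randn(2, hidden)
logits_copy = h @ lm_head_copy.weight.T
logits_tied = h @ lm_head_tied.weight.T
print(torch.allclose(logits_copy, logits_tied))  # True
print(lm_head_tied.weight.data_ptr() == embed_tokens.weight.data_ptr())  # True (shared)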
sglang/srt/models/qwen2_moe.py CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
      tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
  class Qwen2MoeMLP(nn.Module):
@@ -158,7 +158,6 @@ class Qwen2MoeAttention(nn.Module):
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
          max_position_embeddings: int = 8192,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -234,7 +233,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
          self,
          config: PretrainedConfig,
          layer_id: int,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -250,7 +248,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
              rope_theta=rope_theta,
              rope_scaling=rope_scaling,
              max_position_embeddings=max_position_embeddings,
-             cache_config=cache_config,
              quant_config=quant_config,
          )
 
@@ -304,7 +301,6 @@ class Qwen2MoeModel(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -317,9 +313,7 @@ class Qwen2MoeModel(nn.Module):
          )
          self.layers = nn.ModuleList(
              [
-                 Qwen2MoeDecoderLayer(
-                     config, layer_id, cache_config, quant_config=quant_config
-                 )
+                 Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
                  for layer_id in range(config.num_hidden_layers)
              ]
          )
@@ -353,14 +347,13 @@ class Qwen2MoeForCausalLM(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.quant_config = quant_config
          self.torchao_config = global_server_args_dict["torchao_config"]
-         self.model = Qwen2MoeModel(config, cache_config, quant_config)
+         self.model = Qwen2MoeModel(config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
@@ -376,7 +369,7 @@ class Qwen2MoeForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )
 
      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
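Several of the load_weights methods touched in this diff (including the new olmo2.py above) use the same stacked_params_mapping idiom: checkpoint tensors named q_proj/k_proj/v_proj and gate_proj/up_proj are routed into the fused qkv_proj and gate_up_proj parameters together with a shard id. The following stripped-down sketch shows only the name-routing part of that loop, using plain strings instead of real parameters; route is an illustrative helper, not an sglang function.

# Minimal illustration of the stacked_params_mapping idiom used by the
# load_weights methods in this diff.
stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def route(checkpoint_name: str):
    """Return (model parameter name, shard id) for a checkpoint tensor name."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, param_name), shard_id
    return checkpoint_name, None  # loaded 1:1, e.g. layernorm weights

print(route("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(route("model.layers.0.mlp.up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)
print(route("model.layers.0.input_layernorm.weight"))
# ('model.layers.0.input_layernorm.weight', None)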