sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
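The hunk above wraps the inference forward pass in @torch.no_grad(); the same one-line change recurs in most of the model files below. A minimal sketch of what the decorator buys, using a toy nn.Linear rather than any sglang module (all names here are illustrative only):

import torch
from torch import nn

model = nn.Linear(16, 16)  # stand-in for a model's forward pass
x = torch.randn(2, 16)

# Without no_grad, autograd records the graph and keeps activations alive.
y = model(x)
print(y.requires_grad)  # True

# Inside no_grad (what @torch.no_grad() applies to the whole forward),
# no graph is built, which trims memory and overhead during inference.
with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False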
sglang/srt/models/internlm2.py ADDED
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
+
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class InternLM2MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.w2 = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.w2(x)
+        return x
+
+
+class InternLM2Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.wqkv = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.wo = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.wqkv(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.wo(attn_output)
+        return output
+
+
+class InternLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            layer_id=layer_id,
+            quant_config=quant_config,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.tok_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                InternLMDecoderLayer(config, i, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.tok_embeddings(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2ForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = InternLM2Model(config, quant_config)
+        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.output.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w1", 0),
+            ("gate_up_proj", "w3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                if "wqkv" in name:
+                    config = self.config
+                    kv_groups = config.num_attention_heads // config.num_key_value_heads
+                    head_dim = config.hidden_size // config.num_attention_heads
+                    loaded_weight = loaded_weight.view(
+                        -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                    )
+                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
+                    wq = wq.reshape(-1, wq.shape[-1])
+                    wk = wk.reshape(-1, wk.shape[-1])
+                    wv = wv.reshape(-1, wv.shape[-1])
+                    weight_loader = param.weight_loader
+                    weight_loader(param, wq, "q")
+                    weight_loader(param, wk, "k")
+                    weight_loader(param, wv, "v")
+                else:
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = InternLM2ForCausalLM
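The wqkv branch of load_weights above unpacks InternLM2's fused attention weight, in which each group of query heads is stored interleaved with its shared key and value head. A standalone sketch of that reshape/split with toy dimensions (the sizes below are illustrative, not taken from a real checkpoint):

import torch

# Toy dimensions (illustrative only)
hidden_size, num_heads, num_kv_heads = 32, 8, 2
head_dim = hidden_size // num_heads      # 4
kv_groups = num_heads // num_kv_heads    # 4 query heads per KV head

# Fused wqkv weight: rows are laid out per KV group as [q * kv_groups, k, v].
fused = torch.randn(num_kv_heads * (kv_groups + 2) * head_dim, hidden_size)

# Same reshape/split as in load_weights above.
w = fused.view(-1, 2 + kv_groups, head_dim, fused.shape[-1])
wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, wq.shape[-1])  # (num_heads * head_dim, hidden_size)
wk = wk.reshape(-1, wk.shape[-1])  # (num_kv_heads * head_dim, hidden_size)
wv = wv.reshape(-1, wv.shape[-1])

print(wq.shape, wk.shape, wv.shape)
# torch.Size([32, 32]) torch.Size([8, 32]) torch.Size([8, 32])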
sglang/srt/models/llama2.py CHANGED
@@ -5,21 +5,12 @@
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
-import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -32,6 +23,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata

+MergedColumnParallelLinear = None
+QKVParallelLinear = None
+RowParallelLinear = None
+

 class LlamaMLP(nn.Module):
     def __init__(
@@ -40,6 +35,7 @@ class LlamaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -47,12 +43,14 @@
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -71,6 +69,7 @@
 class LlamaAttention(nn.Module):
     def __init__(
         self,
+        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -80,6 +79,7 @@ class LlamaAttention(nn.Module):
         rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -97,7 +97,10 @@
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -111,12 +114,14 @@
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )

         self.rotary_emb = get_rope(
@@ -155,6 +160,7 @@ class LlamaDecoderLayer(nn.Module):
         config: LlamaConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,12 +169,13 @@
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling[
-                "original_max_position_embeddings"
-            ] = config.original_max_position_embeddings
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -178,12 +185,14 @@
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -231,7 +240,9 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, quant_config=quant_config)
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -267,7 +278,25 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
+        efficient_weight_load=False,
     ) -> None:
+        global MergedColumnParallelLinear
+        global QKVParallelLinear
+        global RowParallelLinear
+
+        if efficient_weight_load:
+            from sglang.srt.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+        else:
+            from vllm.model_executor.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
@@ -275,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -287,7 +317,30 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def get_module_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -297,15 +350,14 @@
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
+
+        def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
+                return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
+                return
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -322,12 +374,18 @@
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    continue
+                    return
                 if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
+                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+

 EntryClass = LlamaForCausalLM
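get_module_name above maps a checkpoint tensor name to the fused parameter it lands in, plus the number of shards merged into that parameter, and load_weights now also accepts a single name/loaded_weight pair so the efficient loading path can stream one tensor at a time. A self-contained sketch of the name mapping (the logic is copied from the method; the example names are ordinary Llama checkpoint keys):

stacked_params_mapping = [
    # (param_name, shard_name, shard_id, num_shard)
    ("qkv_proj", "q_proj", "q", 3),
    ("qkv_proj", "k_proj", "k", 3),
    ("qkv_proj", "v_proj", "v", 3),
    ("gate_up_proj", "gate_proj", 0, 2),
    ("gate_up_proj", "up_proj", 1, 2),
]

def get_module_name(name: str):
    # Return (fused module name, number of shards merged into it).
    for param_name, weight_name, _shard_id, num_shard in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name)[: -len(".weight")], num_shard
    return name[: -len(".weight")], 1

print(get_module_name("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj', 3)
print(get_module_name("model.layers.0.mlp.gate_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj', 2)
print(get_module_name("model.norm.weight"))
# ('model.norm', 1)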
sglang/srt/models/llama_classification.py CHANGED
@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
         )
         self.eos_token_id = config.eos_token_id

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
sglang/srt/models/llava.py CHANGED
@@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module):

         return image_features

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.LongTensor,
sglang/srt/models/llavavid.py CHANGED
@@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module):

         return image_features

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.LongTensor,
sglang/srt/models/minicpm.py CHANGED
@@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module):

         self.logits_processor = LogitsProcessor(config)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
sglang/srt/models/mixtral.py CHANGED
@@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
sglang/srt/models/mixtral_quant.py CHANGED
@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
sglang/srt/models/qwen.py CHANGED
@@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module):
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
sglang/srt/models/qwen2.py CHANGED
@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -312,6 +313,11 @@ class Qwen2ForCausalLM(nn.Module):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
+            if (
+                self.config.tie_word_embeddings
+                and name == "model.embed_tokens.weight"
+            ):
+                weight_loader(params_dict["lm_head.weight"], loaded_weight)


 EntryClass = Qwen2ForCausalLM
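The second qwen2.py hunk implements weight tying: when config.tie_word_embeddings is set, the input embedding matrix is also loaded into lm_head so the output projection reuses it. A toy illustration of the same idea outside sglang (the module names below are generic, not the sglang classes):

import torch
from torch import nn

vocab, hidden = 100, 16  # toy sizes
embed_tokens = nn.Embedding(vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)

# Tie the output projection to the input embedding, mirroring what the hunk
# does when it copies model.embed_tokens.weight into lm_head.weight.
with torch.no_grad():
    lm_head.weight.copy_(embed_tokens.weight)

assert torch.equal(lm_head.weight, embed_tokens.weight)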