sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +976 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -2
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +39 -24
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
- sglang/srt/managers/controller/infer_batch.py +90 -63
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +41 -26
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +136 -149
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +32 -11
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +81 -23
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +132 -84
- sglang/srt/server_args.py +35 -21
- sglang/srt/utils.py +65 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
- sglang-0.1.24.dist-info/RECORD +105 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
sglang/srt/models/internlm2.py
ADDED
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
+
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class InternLM2MLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.w2 = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.w2(x)
+        return x
+
+
+class InternLM2Attention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.wqkv = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.wo = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.wqkv(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.wo(attn_output)
+        return output
+
+
+class InternLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            layer_id=layer_id,
+            quant_config=quant_config,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2Model(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.tok_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                InternLMDecoderLayer(config, i, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.tok_embeddings(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2ForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = InternLM2Model(config, quant_config)
+        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.output.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "w1", 0),
+            ("gate_up_proj", "w3", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                if "wqkv" in name:
+                    config = self.config
+                    kv_groups = config.num_attention_heads // config.num_key_value_heads
+                    head_dim = config.hidden_size // config.num_attention_heads
+                    loaded_weight = loaded_weight.view(
+                        -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                    )
+                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
+                    wq = wq.reshape(-1, wq.shape[-1])
+                    wk = wk.reshape(-1, wk.shape[-1])
+                    wv = wv.reshape(-1, wv.shape[-1])
+                    weight_loader = param.weight_loader
+                    weight_loader(param, wq, "q")
+                    weight_loader(param, wk, "k")
+                    weight_loader(param, wv, "v")
+                else:
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = InternLM2ForCausalLM
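The one non-boilerplate piece of internlm2.py is load_weights: InternLM2 checkpoints pack the attention projection into a single wqkv tensor, row-blocked per KV group as [q_0 .. q_{kv_groups-1}, k, v], so the loader views it as (-1, 2 + kv_groups, head_dim, in_features) and splits it before handing each slice to the "q", "k", and "v" shards. A self-contained sketch of that reshape with toy dimensions invented for illustration (not a real InternLM2 config):

import torch

# Toy dimensions, invented for this example.
num_attention_heads, num_key_value_heads, head_dim, hidden = 8, 2, 4, 16
kv_groups = num_attention_heads // num_key_value_heads  # 4 query heads per KV head

# The fused projection is stored per KV group: [q_0 .. q_{kv_groups-1}, k, v].
wqkv = torch.randn(num_key_value_heads * (kv_groups + 2) * head_dim, hidden)

# The same view/split performed in load_weights above.
w = wqkv.view(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])
wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, wq.shape[-1])  # (num_attention_heads * head_dim, hidden)
wk = wk.reshape(-1, wk.shape[-1])  # (num_key_value_heads * head_dim, hidden)
wv = wv.reshape(-1, wv.shape[-1])

print(wq.shape, wk.shape, wv.shape)
# torch.Size([32, 16]) torch.Size([8, 16]) torch.Size([8, 16])

With these shapes the three slices line up with what the per-shard weight_loader calls in the file above expect.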
sglang/srt/models/llama2.py
CHANGED
@@ -5,21 +5,12 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -32,6 +23,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
+MergedColumnParallelLinear = None
+QKVParallelLinear = None
+RowParallelLinear = None
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -40,6 +35,7 @@ class LlamaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -47,12 +43,14 @@
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -71,6 +69,7 @@ class LlamaMLP(nn.Module):
 class LlamaAttention(nn.Module):
     def __init__(
         self,
+        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -80,6 +79,7 @@ class LlamaAttention(nn.Module):
         rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -97,7 +97,10 @@
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -111,12 +114,14 @@
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.rotary_emb = get_rope(
@@ -155,6 +160,7 @@ class LlamaDecoderLayer(nn.Module):
         config: LlamaConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,12 +169,13 @@
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling[
-                "original_max_position_embeddings"
-            ] = config.original_max_position_embeddings
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
         rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
@@ -178,12 +185,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -231,7 +240,9 @@ class LlamaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                LlamaDecoderLayer(config, i, quant_config=quant_config)
+                LlamaDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -267,7 +278,25 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
+        efficient_weight_load=False,
     ) -> None:
+        global MergedColumnParallelLinear
+        global QKVParallelLinear
+        global RowParallelLinear
+
+        if efficient_weight_load:
+            from sglang.srt.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+        else:
+            from vllm.model_executor.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
@@ -275,6 +304,7 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -287,7 +317,30 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def get_module_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -297,15 +350,14 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
+
+        def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
+                return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
+                return
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -322,12 +374,18 @@ class LlamaForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
-                    continue
+                    return
                 if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
+                    return
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
 
 EntryClass = LlamaForCausalLM
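Taken together, the llama2.py changes drop the rank-0 tqdm progress bar, thread a prefix string through every parallel linear layer so quantization configs can be resolved per layer name, let an efficient_weight_load flag swap in sglang's own linear layers in place of vllm's, and split load_weights into an inner load_weights_per_param so a single named tensor can be pushed without iterating a whole checkpoint. The sketch below mirrors only that last call-path refactor on a toy module; TinyCausalLM and its plain copy_-based loading are simplified stand-ins, not the real sharded loaders:

from typing import Iterable, Optional, Tuple

import torch
from torch import nn


class TinyCausalLM(nn.Module):
    """Illustrative stand-in for the refactor above, not the sglang class."""

    def __init__(self) -> None:
        super().__init__()
        self.lm_head = nn.Linear(4, 4, bias=False)

    def load_weights(
        self,
        weights: Optional[Iterable[Tuple[str, torch.Tensor]]] = None,
        name: Optional[str] = None,
        loaded_weight: Optional[torch.Tensor] = None,
    ):
        params_dict = dict(self.named_parameters())

        def load_weights_per_param(name, loaded_weight):
            if name not in params_dict:
                return  # skip tensors this toy model does not own
            params_dict[name].data.copy_(loaded_weight)

        if name is None or loaded_weight is None:
            # Bulk path: stream every tensor from a checkpoint iterator.
            for name, loaded_weight in weights:
                load_weights_per_param(name, loaded_weight)
        else:
            # Single-tensor path: push one named weight without re-reading
            # the whole checkpoint.
            load_weights_per_param(name, loaded_weight)


model = TinyCausalLM()
model.load_weights(weights=[("lm_head.weight", torch.eye(4))])
model.load_weights(name="lm_head.weight", loaded_weight=torch.zeros(4, 4))

Keeping one per-tensor code path for both call styles avoids duplicating the shard-mapping logic between bulk checkpoint loading and incremental weight updates.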
sglang/srt/models/llava.py
CHANGED
sglang/srt/models/llavavid.py
CHANGED
sglang/srt/models/minicpm.py
CHANGED
sglang/srt/models/mixtral.py
CHANGED
sglang/srt/models/qwen.py
CHANGED
sglang/srt/models/qwen2.py
CHANGED
@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -312,6 +313,11 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
+                if (
+                    self.config.tie_word_embeddings
+                    and name == "model.embed_tokens.weight"
+                ):
+                    weight_loader(params_dict["lm_head.weight"], loaded_weight)
 
 
 EntryClass = Qwen2ForCausalLM
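The qwen2.py addition handles configs with tie_word_embeddings=True, where the checkpoint ships only model.embed_tokens.weight and the LM head is expected to reuse it; the loader now copies that tensor into lm_head.weight as soon as it is seen. A toy illustration of the same idea with standalone modules (shapes and the checkpoint dict are invented for the example):

import torch
from torch import nn

vocab, hidden = 10, 4
embed_tokens = nn.Embedding(vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)

tie_word_embeddings = True
checkpoint = {"model.embed_tokens.weight": torch.randn(vocab, hidden)}

for name, loaded_weight in checkpoint.items():
    if name == "model.embed_tokens.weight":
        embed_tokens.weight.data.copy_(loaded_weight)
        if tie_word_embeddings:
            # Reuse the embedding matrix for the output projection,
            # since tied checkpoints do not store lm_head.weight at all.
            lm_head.weight.data.copy_(loaded_weight)

assert torch.equal(embed_tokens.weight, lm_head.weight)

Without this copy, a tied-embedding Qwen2 checkpoint would leave lm_head at its random initialization, producing garbage logits despite all listed weights loading "successfully".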