sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +30 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/backend/runtime_endpoint.py +18 -14
- sglang/bench_latency.py +317 -0
- sglang/global_config.py +5 -1
- sglang/lang/chat_template.py +41 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +6 -2
- sglang/lang/ir.py +74 -28
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +68 -9
- sglang/srt/layers/extend_attention.py +2 -1
- sglang/srt/layers/fused_moe.py +280 -169
- sglang/srt/layers/logits_processor.py +106 -42
- sglang/srt/layers/radix_attention.py +53 -29
- sglang/srt/layers/token_attention.py +4 -1
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +144 -69
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +9 -4
- sglang/srt/managers/controller/model_runner.py +167 -55
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +156 -134
- sglang/srt/managers/detokenizer_manager.py +19 -21
- sglang/srt/managers/io_struct.py +11 -5
- sglang/srt/managers/tokenizer_manager.py +16 -14
- sglang/srt/model_config.py +89 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +12 -5
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +373 -0
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +454 -0
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +35 -25
- sglang/srt/openai_protocol.py +2 -2
- sglang/srt/server.py +69 -19
- sglang/srt/server_args.py +76 -43
- sglang/srt/utils.py +177 -35
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
- sglang-0.1.19.dist-info/RECORD +81 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,436 @@
|
|
1
|
+
# Adapted from:
|
2
|
+
# https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
|
3
|
+
from typing import Iterable, Optional, Set, Tuple, Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from torch import nn
|
7
|
+
from transformers import PretrainedConfig
|
8
|
+
from vllm.config import CacheConfig, LoRAConfig
|
9
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
10
|
+
|
11
|
+
# FIXME: temporary solution, remove after next vllm release
|
12
|
+
from vllm.model_executor.custom_op import CustomOp
|
13
|
+
from vllm.model_executor.layers.activation import GeluAndMul
|
14
|
+
|
15
|
+
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
|
16
|
+
from vllm.model_executor.layers.linear import (
|
17
|
+
MergedColumnParallelLinear,
|
18
|
+
QKVParallelLinear,
|
19
|
+
RowParallelLinear,
|
20
|
+
)
|
21
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
22
|
+
|
23
|
+
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
|
24
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
25
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
26
|
+
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
27
|
+
|
28
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
29
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
30
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
31
|
+
|
32
|
+
|
33
|
+
class GemmaRMSNorm(CustomOp):
    """Root-mean-square layer norm using Gemma's weight convention.

    Compared with a Llama-style RMSNorm there are two differences:
      1. the learned weight is applied as an offset, ``x * (1 + w)``, rather
         than ``x * w`` (which is why ``weight`` is zero-initialised);
      2. the weight is applied in float32 before casting back, i.e.
         ``(x * w).to(orig_dtype)`` rather than ``x.to(orig_dtype) * w``.

    If ``residual`` is supplied it is added to ``x`` first; the sum is then
    both normalised and returned as the new residual.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Zero init: forward multiplies by (1 + weight), so this starts as identity.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        input_dtype = x.dtype
        has_residual = residual is not None
        if has_residual:
            # Fused residual add: the pre-norm sum becomes the next residual.
            x = x + residual
            residual = x

        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + self.variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + self.weight.float())
        hidden = hidden.to(input_dtype)

        if not has_residual:
            return hidden
        return hidden, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path; currently identical to :meth:`forward_native`."""
        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
        return self.forward_native(x, residual)
|
77
|
+
|
78
|
+
|
79
|
+
# FIXME: temporary solution, remove after next vllm release
|
80
|
+
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
81
|
+
|
82
|
+
|
83
|
+
class GemmaRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding with the Hugging Face Gemma inverse-frequency recipe.

    Only ``_compute_inv_freq`` is overridden. The even-index grid is built in
    int64 and then cast to float32 before dividing, following the HF
    reference linked below.
    """

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
        even_dims = torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
        exponents = even_dims / self.rotary_dim
        return 1.0 / (base**exponents)
|
94
|
+
|
95
|
+
|
96
|
+
class Gemma2MLP(nn.Module):
    """Gated feed-forward block of a Gemma 2 layer.

    ``gate_up_proj`` computes two ``intermediate_size`` shards in a single
    projection. vLLM's ``GeluAndMul`` (tanh approximation) combines them;
    presumably ``gelu(gate) * up``, so confirm against vLLM's activation
    module. ``down_proj`` maps back to ``hidden_size``.

    Raises:
        ValueError: unless both ``hidden_act`` and ``hidden_activation`` are
            ``"gelu_pytorch_tanh"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        hidden_activation: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        expected_act = "gelu_pytorch_tanh"
        if hidden_activation != expected_act or hidden_act != hidden_activation:
            raise ValueError(
                "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_act` and `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        self.act_fn = GeluAndMul(approximate="tanh")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, gated GELU, then down projection."""
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out
|
125
|
+
|
126
|
+
|
127
|
+
class Gemma2Attention(nn.Module):
    """Self-attention for one Gemma 2 layer, sharded over tensor parallelism.

    Query heads are split evenly across ranks. KV heads are split when there
    are at least as many as ranks; otherwise each KV head is replicated. The
    softmax scale is ``config.query_pre_attn_scalar ** -0.5``, and
    ``config.attn_logit_softcapping`` is passed to ``RadixAttention`` as
    ``logit_cap``.

    Known limitation (inherited from vLLM): Gemma 2 uses sliding-window
    attention on odd layers, but every layer here uses global attention.
    """

    def __init__(
        self,
        layer_idx: int,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        max_position_embeddings: int,
        rope_theta: float,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.layer_idx = layer_idx
        self.config = config
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # Fewer KV heads than ranks -> replicate each KV head on several ranks;
        # otherwise partition them evenly.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world:
            assert world % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = config.query_pre_attn_scalar**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )
        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
        self.rotary_emb = GemmaRotaryEmbedding(
            head_dim,
            head_dim,
            max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
            dtype=torch.get_default_dtype(),
        )

        # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention
        # for every odd layer, vLLM currently ignores it and uses global
        # attention for all layers.
        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
        del use_sliding_window  # Unused.

        self.attn = RadixAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_idx,
            logit_cap=config.attn_logit_softcapping,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, input_metadata)
        projected, _ = self.o_proj(context)
        return projected
|
215
|
+
|
216
|
+
|
217
|
+
class Gemma2DecoderLayer(nn.Module):
    """One Gemma 2 transformer layer.

    Each sub-block is sandwiched between two norms:
        input_layernorm -> self_attn -> post_attention_layernorm
        pre_feedforward_layernorm -> mlp -> post_feedforward_layernorm
    The residual stream is carried through the fused add+norm path of
    ``GemmaRMSNorm`` (called with ``residual``), not added explicitly here.
    """

    def __init__(
        self,
        layer_idx: int,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Gemma2Attention(
            layer_idx=layer_idx,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            rope_theta=config.rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = Gemma2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            hidden_activation=config.hidden_activation,
            quant_config=quant_config,
        )
        dim, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = GemmaRMSNorm(dim, eps=eps)
        self.post_attention_layernorm = GemmaRMSNorm(dim, eps=eps)
        self.pre_feedforward_layernorm = GemmaRMSNorm(dim, eps=eps)
        self.post_feedforward_layernorm = GemmaRMSNorm(dim, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(hidden_states, residual)``.

        ``residual`` is None for the first layer, in which case the incoming
        ``hidden_states`` becomes the residual.
        """
        if residual is None:
            residual = hidden_states
            normed = self.input_layernorm(hidden_states)
        else:
            normed, residual = self.input_layernorm(hidden_states, residual)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=normed,
            input_metadata=input_metadata,
        )
        attn_out = self.post_attention_layernorm(attn_out)

        ffn_in, residual = self.pre_feedforward_layernorm(attn_out, residual)
        ffn_out = self.mlp(ffn_in)
        ffn_out = self.post_feedforward_layernorm(ffn_out)
        return ffn_out, residual
|
283
|
+
|
284
|
+
|
285
|
+
class Gemma2Model(nn.Module):
    """Gemma 2 decoder stack: scaled token embeddings, decoder layers, final norm."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Normalize the embedding by sqrt(hidden_size)
        # The normalizer's data type should be downcasted to the model's
        # data type such as bfloat16, not float32.
        # See https://github.com/huggingface/transformers/pull/29402
        # NOTE: forward() does not read this buffer (it builds the normalizer in
        # the activation dtype); it is kept so the module's buffers are unchanged.
        normalizer = self.config.hidden_size**0.5
        self.register_buffer("normalizer", torch.tensor(normalizer))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final-normed hidden states.

        Args:
            input_ids: token ids; only used when ``input_embeds`` is None.
            positions: positions forwarded to each layer's rotary embedding.
            input_metadata: batch metadata forwarded to each attention layer.
            input_embeds: optional precomputed embeddings. They are scaled out
                of place, so the caller's tensor is left unmodified.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Scale by sqrt(hidden_size), with the constant rounded to the
        # activation dtype (e.g. bfloat16), as in HF transformers PR 29402.
        # A fixed float16 constant rounds differently for bf16/fp32 models.
        # Out of place so a caller-supplied input_embeds is not mutated.
        normalizer = torch.tensor(
            self.config.hidden_size**0.5, dtype=hidden_states.dtype
        )
        hidden_states = hidden_states * normalizer

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
|
339
|
+
|
340
|
+
|
341
|
+
class Gemma2ForCausalLM(nn.Module):
    """Gemma 2 causal LM: ``Gemma2Model`` plus logits over the tied embedding."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        del lora_config  # Unused.
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma2Model(config, cache_config, quant_config)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the model and return ``LogitsProcessor``'s output.

        The output projection reuses ``model.embed_tokens.weight`` (tied
        embeddings), so there is no separate ``lm_head`` parameter.
        """
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate ``q/k/v_proj`` and ``gate/up_proj`` checkpoint tensors are
        routed into the fused ``qkv_proj`` / ``gate_up_proj`` parameters via
        the parameter's ``weight_loader`` with the matching shard id.
        ``lm_head.weight`` is skipped (embeddings are tied), and bias tensors
        with no matching model parameter (extra GPTQ biases) are dropped.

        Raises:
            RuntimeError: if any model parameter was not initialised from
                ``weights``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # Rewrite the first matching shard name to its fused parameter and
            # stop scanning: the rewritten name must not be matched again
            # (e.g. "qkv_proj" itself contains "v_proj").
            shard_id = None
            for param_name, shard_name, candidate_shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, param_name)
                    shard_id = candidate_shard_id
                    break

            # lm_head is not used in vllm as it is tied with embed_token.
            # To prevent errors, skip loading lm_head.weight.
            if shard_id is None and "lm_head.weight" in name:
                continue
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            # shard_id may be 0 for gate_proj, so compare against None explicitly.
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}"
            )
|
434
|
+
|
435
|
+
|
436
|
+
# Module-level entry point: the model class this file exposes. The code that
# reads `EntryClass` is outside this file (presumably sglang's model
# registry/loader — confirm).
EntryClass = Gemma2ForCausalLM
|