sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3n_causal.py (new file)
@@ -0,0 +1,1009 @@
from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, Gemma3nTextConfig, PretrainedConfig, PreTrainedModel

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding
from sglang.srt.utils import add_prefix, make_layers


# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_attention_sliding_window_size(config):
    return config.sliding_window - 1

class Gemma3nRMSNorm(RMSNorm):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        with_scale: bool = True,
    ) -> None:
        super().__init__(dim, eps=eps)
        if not with_scale:
            del self.weight
            self.register_buffer(
                "weight",
                torch.ones(dim, dtype=torch.get_default_dtype()),
                persistent=False,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_shape = x.shape
        x_2d = x.contiguous().reshape(-1, original_shape[-1])
        x_2d = super().forward(x_2d)
        x = x_2d.reshape(original_shape)
        return x


class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
    pass


class Gemma3nMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_activation: str,
        activation_sparsity: float = 0.0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("down_proj", prefix),
        )
        if hidden_activation != "gelu_pytorch_tanh":
            raise ValueError(
                "Gemma3n uses `gelu_pytorch_tanh` as the hidden activation "
                "function. Please set `hidden_activation` to "
                "`gelu_pytorch_tanh`."
            )
        # Use proper GELU with tanh approximation as specified
        self.act_fn = GeluAndMul()
        self.activation_sparsity = activation_sparsity
        self.register_buffer(
            "target_sparsity_tensor",
            torch.tensor(self.activation_sparsity, dtype=torch.float32),
            persistent=False,
        )  # moved from _gaussian_topk for cuda graph

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)

        # Split gate and up projections
        gate_proj, up_proj = gate_up.chunk(2, dim=-1)

        # Apply activation sparsity if needed
        if self.activation_sparsity > 0.0:
            gate_proj = self._gaussian_topk(gate_proj)

        gate_up = torch.cat([gate_proj, up_proj], dim=-1)

        # Apply GELU activation to gate projection and multiply with up projection
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
        normal_dist = torch.distributions.normal.Normal(0, 1)
        std_multiplier = normal_dist.icdf(self.target_sparsity_tensor)
        std_multiplier = std_multiplier.type(inputs.dtype)
        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = inputs_mean + inputs_std * std_multiplier
        return F.relu(inputs - cutoff_x)

class Gemma3nLaurelBlock(nn.Module):
    """Learned Augmented Residual Layer"""

    def __init__(
        self,
        config: Gemma3nTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.linear_left = ColumnParallelLinear(
            config.hidden_size,
            config.laurel_rank,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("linear_left", prefix),
        )
        self.linear_right = RowParallelLinear(
            config.laurel_rank,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("linear_right", prefix),
        )
        self.post_laurel_norm = Gemma3nRMSNorm(
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [num_tokens, hidden_size]
        laurel_x, _ = self.linear_left(x)
        laurel_x, _ = self.linear_right(laurel_x)
        normed_laurel_x = self.post_laurel_norm(laurel_x)
        return x + normed_laurel_x

class Gemma3nAltUp(nn.Module):
    """Alternating Updates (AltUp)"""

    def __init__(
        self,
        config: Gemma3nTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config

        self.correct_output_scale = nn.Parameter(
            torch.zeros(config.hidden_size, dtype=torch.float32)
        )
        self.correction_coefs = ColumnParallelLinear(
            config.altup_num_inputs,
            config.altup_num_inputs,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("correction_coefs", prefix),
        )
        self.prediction_coefs = ColumnParallelLinear(
            config.altup_num_inputs,
            config.altup_num_inputs**2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("prediction_coefs", prefix),
        )
        self.modality_router = ColumnParallelLinear(
            config.hidden_size,
            config.altup_num_inputs,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("modality_router", prefix),
        )

        self.router_norm = Gemma3nRMSNorm(
            dim=config.hidden_size,
            eps=config.rms_norm_eps,
        )

        self.register_buffer(
            "router_input_scale",
            torch.tensor(config.hidden_size**-1.0),
            persistent=False,
        )

    def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
        # x : [num_tokens, hidden_size]
        router_inputs = self.router_norm(x) * self.router_input_scale.to(
            self.router_norm.weight.dtype
        )
        # router_inputs : [num_tokens, hidden_size]
        routed, _ = self.modality_router(router_inputs)

        # routed : [num_tokens, altup_num_inputs]
        return torch.tanh(routed.float()).type_as(routed)

    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Predicts the output of a layer using a trainable map.
        hidden_states: [num_altup_inputs, num_tokens, hidden_size]
        """
        modalities = self.compute_router_modalities(
            hidden_states[self.config.altup_active_idx]
        )  # (n_tokens, altup_num_inputs)
        # TODO: CHECK DO WE NEED THIS: self.prediction_coefs.float() # Force computation in float32, in-place operation

        if self.config.altup_coef_clip is not None:
            self.prediction_coefs.weight.data.clamp_(
                -self.config.altup_coef_clip, self.config.altup_coef_clip
            )

        all_coefs, _ = self.prediction_coefs(
            modalities
        )  # (n_tokens, altup_num_inputs) -> (n_tokens, altup_num_inputs**2)

        all_coefs = all_coefs.reshape(
            *modalities.shape[:-1],
            self.config.altup_num_inputs,
            self.config.altup_num_inputs,
        ).permute(0, 2, 1)

        # permute hidden_states from [num_altup_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs]
        predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs)
        predictions = predictions.permute(2, 0, 1)  # undo the permute
        predictions += hidden_states  # add the original input
        return predictions.contiguous().type_as(
            hidden_states
        )  # [num_altup_inputs, num_tokens, hidden_size]

    def correct(
        self, predictions: torch.Tensor, activated: torch.Tensor
    ) -> torch.Tensor:
        """Corrects the predictions relative to the activated inputs."""
        # prediction : [num_altup_inputs, num_tokens, hidden_size]
        # activated : [num_tokens, hidden_size]
        modalities = self.compute_router_modalities(
            activated
        )  # [num_tokens, altup_num_inputs]
        innovation = (
            activated - predictions[self.config.altup_active_idx]
        )  # [num_tokens, hidden_size]
        innovation = innovation.repeat(
            self.config.altup_num_inputs, 1, 1
        )  # (self.config.altup_num_inputs, num_tokens, hidden_size)

        if self.config.altup_coef_clip is not None:
            self.correction_coefs.weight.data.clamp_(
                -self.config.altup_coef_clip, self.config.altup_coef_clip
            )

        all_coefs, _ = self.correction_coefs(
            modalities
        )  # [num_tokens, altup_num_inputs]
        all_coefs = (all_coefs + 1.0).permute(1, 0).unsqueeze(-1)
        # [num_tokens, altup_num_inputs, 1]

        corrected = torch.mul(innovation, all_coefs)
        corrected += predictions
        return corrected.contiguous().type_as(activated)

    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
        """Scales the provided 3D tensor."""
        return corrected * self.correct_output_scale.to(corrected.dtype)

    def forward(
        self, hidden_states: torch.Tensor, activated: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts, corrects, and optionally scales the output of a layer using trainable maps.

        hidden_states: [num_altup_inputs, num_tokens, hidden_size]
        """

        predictions = self.predict(hidden_states)
        corrected = self.correct(predictions=predictions, activated=activated)
        output = corrected[self.config.altup_active_idx]
        if self.config.altup_correct_scale:
            output = self.scale_corrected_output(output)
        return corrected, output

class Gemma3nAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        layer_id: int,
        config: Gemma3nTextConfig,
        max_position_embeddings: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads

        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0

        hidden_size = config.hidden_size
        head_dim = getattr(
            config, "head_dim", hidden_size // config.num_attention_heads
        )
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # self.scaling = config.query_rescale_scalar / config.query_pre_attn_scalar
        self.scaling = 1.0

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        # Determine if layer uses sliding window based on pattern
        self.is_sliding = config.layer_types[layer_id] == "sliding_attention"

        # Check if this is a KV shared layer
        first_kv_shared_layer_idx = (
            config.num_hidden_layers - config.num_kv_shared_layers
        )
        self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx

        # Compute the layer index from which shared KV cache values will be retrieved
        if not self.is_kv_shared_layer:
            self.kv_shared_layer_index = None
        elif self.is_sliding:
            self.kv_shared_layer_index = first_kv_shared_layer_idx - 2
        else:
            self.kv_shared_layer_index = first_kv_shared_layer_idx - 1

        if self.is_sliding:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=config.max_position_embeddings,
                base=config.rope_local_base_freq,
                rope_scaling={"rope_type": "default"},
            )
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=config.max_position_embeddings,
                base=config.rope_theta,
                rope_scaling=config.rope_scaling,
            )

        self.sliding_window = config.sliding_window if self.is_sliding else None

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=(
                layer_id if not self.is_kv_shared_layer else self.kv_shared_layer_index
            ),
            logit_cap=0.0,
            sliding_window_size=self.sliding_window,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

        # Gemma3n adds normalization for q, k, v
        self.q_norm = Gemma3nRMSNorm(
            dim=config.head_dim,
            eps=config.rms_norm_eps,
        )
        self.k_norm = Gemma3nRMSNorm(
            dim=config.head_dim,
            eps=config.rms_norm_eps,
        )
        self.v_norm = Gemma3nRMSNorm(
            dim=config.head_dim,
            eps=config.rms_norm_eps,
            with_scale=False,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: Tuple[torch.Tensor, torch.Tensor],
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:

        qkv, _ = self.qkv_proj(hidden_states)
        # TODO: for first 20 layers, we use QKVParallelLinear
        # for others, we only calc Q.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        # Apply normalization to q, k, v
        q = q.unflatten(-1, (self.num_heads, self.head_dim))
        q = self.q_norm(q)

        # Check if we should use shared KV cache
        if self.is_kv_shared_layer and self.kv_shared_layer_index is not None:
            # For KV shared layers, we skip K/V computation and normalization
            # The RadixAttention will handle retrieving shared KV from cache
            k = None
            v = None
        else:
            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
            k = self.k_norm(k)

            v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
            v = self.v_norm(v)

        # Flatten back for rotary embedding
        q = q.flatten(-2, -1)

        # Apply rotary embedding
        if k is not None:
            k = k.flatten(-2, -1)
            q, k = self.rotary_emb(positions, q, k)
            # Reshape k back to head format for attention
            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
        else:
            # For shared KV layers, create a dummy key for rotary embedding and discard it
            dummy_k = torch.zeros_like(
                q[:, : self.kv_size]
            )  # Create dummy key with same shape as needed
            q, _ = self.rotary_emb(positions, q, dummy_k)

        # Reshape q back to head format for attention
        q = q.unflatten(-1, (self.num_heads, self.head_dim))

        attn_output = self.attn(
            q,
            k,
            v,
            forward_batch=forward_batch,
            save_kv_cache=not self.is_kv_shared_layer,
        )

        output, _ = self.o_proj(attn_output)
        return output

class Gemma3nDecoderLayer(nn.Module):
    def __init__(
        self,
        layer_id: int,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_id = layer_id
        self.attention_type = config.layer_types[layer_id]
        self.config = config

        self.self_attn = Gemma3nAttention(
            layer_id=layer_id,
            config=config,
            max_position_embeddings=config.max_position_embeddings,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )

        activation_sparsity = config.activation_sparsity_pattern[layer_id]
        self.mlp = Gemma3nMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_activation=config.hidden_activation,
            activation_sparsity=activation_sparsity,
            quant_config=quant_config,
            prefix=add_prefix("mlp", prefix),
        )

        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Gemma3nRMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_feedforward_layernorm = Gemma3nRMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )
        self.post_feedforward_layernorm = Gemma3nRMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )

        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input

        self.altup = Gemma3nAltUp(
            config, quant_config, prefix=add_prefix("altup", prefix)
        )
        self.laurel = Gemma3nLaurelBlock(
            config, quant_config, prefix=add_prefix("laurel", prefix)
        )

        self.per_layer_input_gate = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_size_per_layer_input,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("per_layer_input_gate", prefix),
        )
        self.per_layer_projection = RowParallelLinear(
            self.hidden_size_per_layer_input,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("per_layer_projection", prefix),
        )
        self.post_per_layer_input_norm = Gemma3nRMSNorm(
            self.hidden_size, eps=config.rms_norm_eps
        )
        self.is_sliding = self.self_attn.is_sliding

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        per_layer_input: torch.Tensor,
        forward_batch: ForwardBatch,
        **kwargs,
    ) -> torch.Tensor:
        predictions = self.altup.predict(
            hidden_states
        )  # [num_altup_inputs, num_tokens, hidden_size]
        active_prediction = predictions[self.config.altup_active_idx]

        active_prediction_normed = self.input_layernorm(active_prediction)
        laurel_output = self.laurel(
            active_prediction_normed
        )  # laurel_output: [num_tokens, hidden_size]
        # active_prediction: [num_tokens, hidden_size]

        attn = self.self_attn(
            positions=positions,
            hidden_states=active_prediction_normed,
            forward_batch=forward_batch,
            **kwargs,
        )
        attn = self.post_attention_layernorm(attn)  # [num_tokens, hidden_size]

        attn_gated = active_prediction + attn  # [num_tokens, hidden_size]
        attn_laurel = (attn_gated + laurel_output) / torch.sqrt(torch.tensor(2.0))

        attn_norm = self.pre_feedforward_layernorm(
            attn_laurel
        )  # [num_tokens, hidden_size]
        attn_ffw = self.mlp(attn_norm)  # [num_tokens, hidden_size]
        attn_ffw_norm = self.post_feedforward_layernorm(
            attn_ffw
        )  # [num_tokens, hidden_size]
        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm  # [num_tokens, hidden_size]
        corrected_predictions = self.altup.correct(
            predictions, attn_ffw_laurel_gated
        )  # prediction : [num_altup_inputs, num_tokens, hidden_size]
        # attn_ffw_laurel_gated: [num_tokens, hidden_size]
        first_prediction = corrected_predictions[self.config.altup_active_idx]

        if self.config.altup_correct_scale:
            first_prediction = self.altup.scale_corrected_output(first_prediction)

        # per_layer_input_gate
        first_prediction = first_prediction.to(self.per_layer_input_gate.weight.dtype)
        first_prediction, _ = self.per_layer_input_gate(first_prediction)
        first_prediction = F.gelu(first_prediction, approximate="tanh")
        first_prediction = torch.multiply(first_prediction, per_layer_input)

        # per_layer_projection
        first_prediction, _ = self.per_layer_projection(first_prediction)
        first_prediction = self.post_per_layer_input_norm(first_prediction)
        corrected_predictions[1:] += first_prediction

        return corrected_predictions

class Gemma3nTextModel(PreTrainedModel):
    def __init__(
        self,
        config: Gemma3nTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size
        self.padding_idx = config.pad_token_id

        # Gemma3n downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
            config.vocab_size,
            config.hidden_size,
            self.padding_idx,
            embed_scale=self.config.hidden_size**0.5,
        )

        self.norm = Gemma3nRMSNorm(
            config.hidden_size,
            eps=config.rms_norm_eps,
        )

        self.layers = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Gemma3nDecoderLayer(
                layer_id=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("layers", prefix),
        )

        # Per-layer input embeddings
        self.hidden_size = config.hidden_size
        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input

        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
            config.vocab_size_per_layer_input,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            self.padding_idx,
            embed_scale=self.config.hidden_size_per_layer_input**0.5,
        )

        self.per_layer_model_projection = ColumnParallelLinear(
            self.hidden_size,
            config.num_hidden_layers * config.hidden_size_per_layer_input,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("per_layer_model_projection", prefix),
        )

        self.per_layer_projection_norm = Gemma3nRMSNorm(
            dim=config.hidden_size_per_layer_input,
            eps=config.rms_norm_eps,
        )

        self.altup_projections = make_layers(
            self.config.altup_num_inputs - 1,
            lambda idx, prefix: ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("altup_projections", prefix),
        )

        self.altup_unembed_projections = make_layers(
            self.config.altup_num_inputs - 1,
            lambda idx, prefix: ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=add_prefix("altup_unembed_projections", prefix),
        )

        self.register_buffer(
            "per_layer_projection_scale",
            torch.tensor(self.hidden_size**-0.5),
            persistent=False,
        )
        self.register_buffer(
            "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
        )

        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embed_tokens

    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
        embeddings = self.embed_tokens_per_layer(input_ids)
        return embeddings.reshape(
            *input_ids.shape,
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        per_layer_projection, _ = self.per_layer_model_projection(inputs_embeds)
        per_layer_projection *= self.per_layer_projection_scale.type(
            inputs_embeds.dtype
        )
        per_layer_projection = per_layer_projection.reshape(
            *inputs_embeds.shape[:-1],
            self.config.num_hidden_layers,
            self.hidden_size_per_layer_input,
        )
        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)

        if per_layer_inputs is None:
            return per_layer_projection

        if per_layer_projection.shape != per_layer_inputs.shape:
            # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings
            per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]

        return (
            per_layer_projection + per_layer_inputs
        ) * self.per_layer_input_scale.type(inputs_embeds.dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (input_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if input_ids is not None:
            input_embeds = self.embed_tokens(input_ids)
            per_layer_inputs = self.get_per_layer_inputs(input_ids)

        per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs)

        if positions.dim() == 1:
            positions = positions.unsqueeze(0)

        # Expand hidden_states to support per-layer inputs
        target_magnitude = torch.mean(input_embeds**2, dim=-1, keepdim=True) ** 0.5
        epsilon_tensor = torch.tensor(torch.finfo(input_embeds.dtype).min)

        # embed positions
        hidden_states_0 = input_embeds
        temp_hidden_states = [hidden_states_0]

        for i in range(1, self.config.altup_num_inputs):
            altup_proj, _ = self.altup_projections[i - 1](hidden_states_0)
            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)

        hidden_states = torch.stack(
            temp_hidden_states, dim=0
        )  # [num_altup_inputs, n_tokens, hidden_size]

        for layer_idx, layer in enumerate(self.layers):
            per_layer_input = per_layer_inputs[:, layer_idx, :]
            hidden_states = layer(
                positions=positions,
                per_layer_input=per_layer_input,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                **kwargs,
            )

        # Per-layer inputs to single output
        target_magnitude = (
            torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
        )

        temp_hidden_states = [hidden_states[0]]

        for i in range(1, self.config.altup_num_inputs):
            # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
            altup_unemb_proj, _ = self.altup_unembed_projections[i - 1](
                hidden_states[i]
            )
            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
            new_magnitude = (
                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
            )
            current_hidden_state = current_hidden_state * (
                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
            )
            temp_hidden_states.append(current_hidden_state)

        hidden_states = torch.stack(temp_hidden_states)
        hidden_states = torch.mean(hidden_states, dim=0)
        hidden_states = self.norm(hidden_states)

        return hidden_states

class Gemma3nForCausalLM(PreTrainedModel):
    config_class = Gemma3nTextConfig

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    config_class = Gemma3nTextConfig
    base_model_prefix = "language_model"

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        ".q_proj": (".qkv_proj", 0),
        ".k_proj": (".qkv_proj", 1),
        ".v_proj": (".qkv_proj", 2),
        ".gate_proj": (".gate_up_proj", 0),
        ".up_proj": (".gate_up_proj", 1),
    }

    packed_modules_mapping = {
        ".qkv_proj": [
            ".q_proj",
            ".k_proj",
            ".v_proj",
        ],
        ".gate_up_proj": [
            ".gate_proj",
            ".up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        ".qkv_proj",
        ".o_proj",
        ".gate_up_proj",
        ".down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3nTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        self.model = Gemma3nTextModel(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("model", prefix),
        )
        self.logits_processor = LogitsProcessor(config)

        if self.config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_attention_sliding_window_size(self):
        return get_attention_sliding_window_size(self.config)

    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        per_layer_inputs: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> LogitsProcessor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            per_layer_inputs,
            **kwargs,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            name = name.replace("model.language_model.", "model.")
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    # Skip loading weights that are not in the model
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # lm_head is not used in vllm as it is tied with embed_token
                if "lm_head.weight" in name:
                    continue
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if name not in params_dict:
                    # Skip loading weights that are not in the model
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


EntryClass = Gemma3nForCausalLM
AutoModel.register(Gemma3nTextConfig, Gemma3nForCausalLM, exist_ok=True)
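
For orientation only: below is a minimal sketch of how the Gemma3nForCausalLM entry added by this file could be exercised through SGLang's offline engine. The sketch is not part of the diff; the checkpoint id is a placeholder assumption, while sgl.Engine, generate, and shutdown follow SGLang's existing offline API.

# Minimal usage sketch (assumes sglang 0.4.8.post1 is installed and that the
# placeholder checkpoint id below points to a Gemma 3n text checkpoint).
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="google/gemma-3n-E2B-it")  # placeholder model id
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(outputs[0]["text"])
    llm.shutdown()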