sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3n_causal.py (new file)
@@ -0,0 +1,1009 @@
+from typing import Iterable, Optional, Set, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import AutoModel, Gemma3nTextConfig, PretrainedConfig, PreTrainedModel
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.gemma3_causal import Gemma3TextScaledWordEmbedding
+from sglang.srt.utils import add_prefix, make_layers
+
+
+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
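A quick sanity check of the window-size convention above: HF treats `sliding_window` as inclusive of the current token, while SGLang's attention backends expect an exclusive count, hence the `- 1`. A minimal standalone sketch using a mock config object (only the `sliding_window` attribute name is carried over from the code; the 512 value is illustrative):

from types import SimpleNamespace

def get_attention_sliding_window_size(config):
    # Convert HF's inclusive window into SGLang's exclusive convention.
    return config.sliding_window - 1

mock_config = SimpleNamespace(sliding_window=512)  # hypothetical value
assert get_attention_sliding_window_size(mock_config) == 511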
+class Gemma3nRMSNorm(RMSNorm):
+    def __init__(
+        self,
+        dim: int,
+        eps: float = 1e-6,
+        with_scale: bool = True,
+    ) -> None:
+        super().__init__(dim, eps=eps)
+        if not with_scale:
+            del self.weight
+            self.register_buffer(
+                "weight",
+                torch.ones(dim, dtype=torch.get_default_dtype()),
+                persistent=False,
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        original_shape = x.shape
+        x_2d = x.contiguous().reshape(-1, original_shape[-1])
+        x_2d = super().forward(x_2d)
+        x = x_2d.reshape(original_shape)
+        return x
+
+
+class Gemma3nTextScaledWordEmbedding(Gemma3TextScaledWordEmbedding):
+    pass
+
+
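The flatten/unflatten in Gemma3nRMSNorm.forward presumably exists because SGLang's fused RMSNorm kernel operates on 2-D [num_tokens, dim] inputs, while Gemma3n also normalizes 3-D tensors such as per-head Q/K/V. A shape-only sketch with a reference RMSNorm standing in for sglang's fused one (assumption: only the reshape behaviour is being illustrated):

import torch

def rms_norm_2d(x_2d: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Reference RMSNorm on a [num_tokens, dim] tensor.
    variance = x_2d.pow(2).mean(-1, keepdim=True)
    return x_2d * torch.rsqrt(variance + eps) * weight

dim = 8
weight = torch.ones(dim)
x = torch.randn(2, 4, dim)              # e.g. [num_tokens, num_heads, head_dim]
x_2d = x.contiguous().reshape(-1, dim)  # flatten everything but the last dim
out = rms_norm_2d(x_2d, weight).reshape(x.shape)  # restore the original shape
assert out.shape == x.shape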
+class Gemma3nMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_activation: str,
+        activation_sparsity: float = 0.0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        if hidden_activation != "gelu_pytorch_tanh":
+            raise ValueError(
+                "Gemma3n uses `gelu_pytorch_tanh` as the hidden activation "
+                "function. Please set `hidden_activation` to "
+                "`gelu_pytorch_tanh`."
+            )
+        # Use proper GELU with tanh approximation as specified
+        self.act_fn = GeluAndMul()
+        self.activation_sparsity = activation_sparsity
+        self.register_buffer(
+            "target_sparsity_tensor",
+            torch.tensor(self.activation_sparsity, dtype=torch.float32),
+            persistent=False,
+        )  # moved from _gaussian_topk for cuda graph
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+
+        # Split gate and up projections
+        gate_proj, up_proj = gate_up.chunk(2, dim=-1)
+
+        # Apply activation sparsity if needed
+        if self.activation_sparsity > 0.0:
+            gate_proj = self._gaussian_topk(gate_proj)
+
+        gate_up = torch.cat([gate_proj, up_proj], dim=-1)
+
+        # Apply GELU activation to gate projection and multiply with up projection
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+    def _gaussian_topk(self, inputs: torch.Tensor) -> torch.Tensor:
+        normal_dist = torch.distributions.normal.Normal(0, 1)
+        std_multiplier = normal_dist.icdf(self.target_sparsity_tensor)
+        std_multiplier = std_multiplier.type(inputs.dtype)
+        inputs_mean = torch.mean(inputs, dim=-1, keepdim=True)
+        inputs_std = torch.std(inputs, dim=-1, keepdim=True, unbiased=False)
+        cutoff_x = inputs_mean + inputs_std * std_multiplier
+        return F.relu(inputs - cutoff_x)
+
+
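_gaussian_topk implements Gemma3n's activation sparsity: each row of the gate projection is treated as roughly Gaussian, and everything below the `activation_sparsity` quantile is zeroed. A standalone numeric sketch of the same math (the 0.95 sparsity value is illustrative, not taken from any released config):

import torch
import torch.nn.functional as F

def gaussian_topk(inputs: torch.Tensor, target_sparsity: float) -> torch.Tensor:
    # Cut off at the target_sparsity quantile of a per-row Gaussian fit,
    # mirroring Gemma3nMLP._gaussian_topk.
    std_multiplier = torch.distributions.Normal(0, 1).icdf(
        torch.tensor(target_sparsity, dtype=torch.float32)
    ).type(inputs.dtype)
    mean = inputs.mean(dim=-1, keepdim=True)
    std = inputs.std(dim=-1, keepdim=True, unbiased=False)
    return F.relu(inputs - (mean + std * std_multiplier))

x = torch.randn(4, 10_000)
y = gaussian_topk(x, target_sparsity=0.95)
# Roughly 95% of entries end up exactly zero for Gaussian-ish inputs.
print((y == 0).float().mean())  # ~0.95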
+class Gemma3nLaurelBlock(nn.Module):
+    """Learned Augmented Residual Layer"""
+
+    def __init__(
+        self,
+        config: Gemma3nTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.linear_left = ColumnParallelLinear(
+            config.hidden_size,
+            config.laurel_rank,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_left", prefix),
+        )
+        self.linear_right = RowParallelLinear(
+            config.laurel_rank,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("linear_right", prefix),
+        )
+        self.post_laurel_norm = Gemma3nRMSNorm(
+            dim=config.hidden_size,
+            eps=config.rms_norm_eps,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # [num_tokens, hidden_size]
+        laurel_x, _ = self.linear_left(x)
+        laurel_x, _ = self.linear_right(laurel_x)
+        normed_laurel_x = self.post_laurel_norm(laurel_x)
+        return x + normed_laurel_x
+
+
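The Laurel block is a low-rank residual branch that runs alongside attention: compress to `laurel_rank`, expand back, RMS-normalise, and add to the input. A plain-PyTorch equivalent of the forward pass (un-sharded nn.Linear layers and torch's built-in RMSNorm stand in for the tensor-parallel layers; `hidden_size=256` and `laurel_rank=64` are illustrative):

import torch
from torch import nn

class LaurelSketch(nn.Module):
    def __init__(self, hidden_size: int, laurel_rank: int, eps: float = 1e-6):
        super().__init__()
        self.linear_left = nn.Linear(hidden_size, laurel_rank, bias=False)
        self.linear_right = nn.Linear(laurel_rank, hidden_size, bias=False)
        self.post_laurel_norm = nn.RMSNorm(hidden_size, eps=eps)  # torch >= 2.4

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, hidden_size]; low-rank bottleneck, then a normalised residual.
        laurel_x = self.linear_right(self.linear_left(x))
        return x + self.post_laurel_norm(laurel_x)

block = LaurelSketch(hidden_size=256, laurel_rank=64)
print(block(torch.randn(3, 256)).shape)  # torch.Size([3, 256])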
+class Gemma3nAltUp(nn.Module):
+    """Alternating Updates (AltUp)"""
+
+    def __init__(
+        self,
+        config: Gemma3nTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.correct_output_scale = nn.Parameter(
+            torch.zeros(config.hidden_size, dtype=torch.float32)
+        )
+        self.correction_coefs = ColumnParallelLinear(
+            config.altup_num_inputs,
+            config.altup_num_inputs,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("correction_coefs", prefix),
+        )
+        self.prediction_coefs = ColumnParallelLinear(
+            config.altup_num_inputs,
+            config.altup_num_inputs**2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("prediction_coefs", prefix),
+        )
+        self.modality_router = ColumnParallelLinear(
+            config.hidden_size,
+            config.altup_num_inputs,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("modality_router", prefix),
+        )
+
+        self.router_norm = Gemma3nRMSNorm(
+            dim=config.hidden_size,
+            eps=config.rms_norm_eps,
+        )
+
+        self.register_buffer(
+            "router_input_scale",
+            torch.tensor(config.hidden_size**-1.0),
+            persistent=False,
+        )
+
+    def compute_router_modalities(self, x: torch.Tensor) -> torch.Tensor:
+        # x : [num_tokens, hidden_size]
+        router_inputs = self.router_norm(x) * self.router_input_scale.to(
+            self.router_norm.weight.dtype
+        )
+        # router_inputs : [num_tokens, hidden_size]
+        routed, _ = self.modality_router(router_inputs)
+
+        # routed : [num_tokens, altup_num_inputs]
+        return torch.tanh(routed.float()).type_as(routed)
+
+    def predict(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Predicts the output of a layer using a trainable map.
+        hidden_states: [num_altup_inputs, num_tokens, hidden_size]
+        """
+        modalities = self.compute_router_modalities(
+            hidden_states[self.config.altup_active_idx]
+        )  # (n_tokens, altup_num_inputs)
+        # TODO: CHECK DO WE NEED THIS: self.prediction_coefs.float() # Force computation in float32, in-place operation
+
+        if self.config.altup_coef_clip is not None:
+            self.prediction_coefs.weight.data.clamp_(
+                -self.config.altup_coef_clip, self.config.altup_coef_clip
+            )
+
+        all_coefs, _ = self.prediction_coefs(
+            modalities
+        )  # (n_tokens, altup_num_inputs) -> (n_tokens, altup_num_inputs**2)
+
+        all_coefs = all_coefs.reshape(
+            *modalities.shape[:-1],
+            self.config.altup_num_inputs,
+            self.config.altup_num_inputs,
+        ).permute(0, 2, 1)
+
+        # permute hidden_states from [num_altup_inputs, num_tokens, hidden_size] to [num_tokens, hidden_size, altup_num_inputs]
+        predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs)
+        predictions = predictions.permute(2, 0, 1)  # undo the permute
+        predictions += hidden_states  # add the original input
+        return predictions.contiguous().type_as(
+            hidden_states
+        )  # [num_altup_inputs, num_tokens, hidden_size]
+
+    def correct(
+        self, predictions: torch.Tensor, activated: torch.Tensor
+    ) -> torch.Tensor:
+        """Corrects the predictions relative to the activated inputs."""
+        # prediction : [num_altup_inputs, num_tokens, hidden_size]
+        # activated : [num_tokens, hidden_size]
+        modalities = self.compute_router_modalities(
+            activated
+        )  # [num_tokens, altup_num_inputs]
+        innovation = (
+            activated - predictions[self.config.altup_active_idx]
+        )  # [num_tokens, hidden_size]
+        innovation = innovation.repeat(
+            self.config.altup_num_inputs, 1, 1
+        )  # (self.config.altup_num_inputs, num_tokens, hidden_size)
+
+        if self.config.altup_coef_clip is not None:
+            self.correction_coefs.weight.data.clamp_(
+                -self.config.altup_coef_clip, self.config.altup_coef_clip
+            )
+
+        all_coefs, _ = self.correction_coefs(
+            modalities
+        )  # [num_tokens, altup_num_inputs]
+        all_coefs = (all_coefs + 1.0).permute(1, 0).unsqueeze(-1)
+        # [altup_num_inputs, num_tokens, 1]
+
+        corrected = torch.mul(innovation, all_coefs)
+        corrected += predictions
+        return corrected.contiguous().type_as(activated)
+
+    def scale_corrected_output(self, corrected: torch.Tensor) -> torch.Tensor:
+        """Scales the provided 3D tensor."""
+        return corrected * self.correct_output_scale.to(corrected.dtype)
+
+    def forward(
+        self, hidden_states: torch.Tensor, activated: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts, corrects, and optionally scales the output of a layer using trainable maps.
+
+        hidden_states: [num_altup_inputs, num_tokens, hidden_size]
+        """
+
+        predictions = self.predict(hidden_states)
+        corrected = self.correct(predictions=predictions, activated=activated)
+        output = corrected[self.config.altup_active_idx]
+        if self.config.altup_correct_scale:
+            output = self.scale_corrected_output(output)
+        return corrected, output
+
+
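AltUp keeps `altup_num_inputs` parallel hidden states but runs attention and the MLP on only one of them (`altup_active_idx`): `predict` mixes the stacked states with learned per-token coefficients, and `correct` propagates the "innovation" (activated minus predicted) back to every state. A shape-only sketch of the two contractions, with random tensors in place of the `prediction_coefs`/`correction_coefs` outputs and toy sizes:

import torch

num_inputs, num_tokens, hidden = 4, 3, 8   # altup_num_inputs, n_tokens, hidden_size
hidden_states = torch.randn(num_inputs, num_tokens, hidden)

# predict(): a per-token [num_inputs, num_inputs] mixing matrix applied across the
# stacked hidden states, plus the original states as a residual.
all_coefs = torch.randn(num_tokens, num_inputs, num_inputs)
predictions = torch.matmul(hidden_states.permute(1, 2, 0), all_coefs)  # [tokens, hidden, inputs]
predictions = predictions.permute(2, 0, 1) + hidden_states             # back to [inputs, tokens, hidden]

# correct(): broadcast the active state's innovation to all states with per-token,
# per-state scalars (the `+ 1.0` keeps the update close to identity).
active_idx = 0
activated = torch.randn(num_tokens, hidden)        # attention/MLP output of the active state
innovation = activated - predictions[active_idx]   # [tokens, hidden]
innovation = innovation.repeat(num_inputs, 1, 1)   # [inputs, tokens, hidden]
coefs = (torch.randn(num_tokens, num_inputs) + 1.0).permute(1, 0).unsqueeze(-1)  # [inputs, tokens, 1]
corrected = innovation * coefs + predictions
print(corrected.shape)  # torch.Size([4, 3, 8])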
+class Gemma3nAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: Gemma3nTextConfig,
+        max_position_embeddings: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+
+        hidden_size = config.hidden_size
+        head_dim = getattr(
+            config, "head_dim", hidden_size // config.num_attention_heads
+        )
+        self.head_dim = head_dim
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        # self.scaling = config.query_rescale_scalar / config.query_pre_attn_scalar
+        self.scaling = 1.0
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        # Determine if layer uses sliding window based on pattern
+        self.is_sliding = config.layer_types[layer_id] == "sliding_attention"
+
+        # Check if this is a KV shared layer
+        first_kv_shared_layer_idx = (
+            config.num_hidden_layers - config.num_kv_shared_layers
+        )
+        self.is_kv_shared_layer = layer_id >= first_kv_shared_layer_idx
+
+        # Compute the layer index from which shared KV cache values will be retrieved
+        if not self.is_kv_shared_layer:
+            self.kv_shared_layer_index = None
+        elif self.is_sliding:
+            self.kv_shared_layer_index = first_kv_shared_layer_idx - 2
+        else:
+            self.kv_shared_layer_index = first_kv_shared_layer_idx - 1
+
+        if self.is_sliding:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=config.max_position_embeddings,
+                base=config.rope_local_base_freq,
+                rope_scaling={"rope_type": "default"},
+            )
+        else:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=self.head_dim,
+                max_position=config.max_position_embeddings,
+                base=config.rope_theta,
+                rope_scaling=config.rope_scaling,
+            )
+
+        self.sliding_window = config.sliding_window if self.is_sliding else None
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=(
+                layer_id if not self.is_kv_shared_layer else self.kv_shared_layer_index
+            ),
+            logit_cap=0.0,
+            sliding_window_size=self.sliding_window,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        # Gemma3n adds normalization for q, k, v
+        self.q_norm = Gemma3nRMSNorm(
+            dim=config.head_dim,
+            eps=config.rms_norm_eps,
+        )
+        self.k_norm = Gemma3nRMSNorm(
+            dim=config.head_dim,
+            eps=config.rms_norm_eps,
+        )
+        self.v_norm = Gemma3nRMSNorm(
+            dim=config.head_dim,
+            eps=config.rms_norm_eps,
+            with_scale=False,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        positions: Tuple[torch.Tensor, torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        qkv, _ = self.qkv_proj(hidden_states)
+        # TODO: for first 20 layers, we use QKVParallelLinear
+        # for others, we only calc Q.
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # Apply normalization to q, k, v
+        q = q.unflatten(-1, (self.num_heads, self.head_dim))
+        q = self.q_norm(q)
+
+        # Check if we should use shared KV cache
+        if self.is_kv_shared_layer and self.kv_shared_layer_index is not None:
+            # For KV shared layers, we skip K/V computation and normalization
+            # The RadixAttention will handle retrieving shared KV from cache
+            k = None
+            v = None
+        else:
+            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+            k = self.k_norm(k)
+
+            v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
+            v = self.v_norm(v)
+
+        # Flatten back for rotary embedding
+        q = q.flatten(-2, -1)
+
+        # Apply rotary embedding
+        if k is not None:
+            k = k.flatten(-2, -1)
+            q, k = self.rotary_emb(positions, q, k)
+            # Reshape k back to head format for attention
+            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+        else:
+            # For shared KV layers, create a dummy key for rotary embedding and discard it
+            dummy_k = torch.zeros_like(
+                q[:, : self.kv_size]
+            )  # Create dummy key with same shape as needed
+            q, _ = self.rotary_emb(positions, q, dummy_k)
+
+        # Reshape q back to head format for attention
+        q = q.unflatten(-1, (self.num_heads, self.head_dim))
+
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            forward_batch=forward_batch,
+            save_kv_cache=not self.is_kv_shared_layer,
+        )
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
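The last `num_kv_shared_layers` decoder layers do not write their own KV cache; they reuse the cache of the last non-shared layer of the same attention type, which is why `layer_id` passed to RadixAttention is remapped and `save_kv_cache` is disabled for them. A small sketch of that index computation with made-up layer counts (Gemma3n's real `layer_types` pattern comes from the checkpoint config; the alternating pattern below is only illustrative):

# Hypothetical 8-layer model where the last 4 layers share KV caches.
num_hidden_layers = 8
num_kv_shared_layers = 4
layer_types = ["sliding_attention", "full_attention"] * 4  # illustrative pattern

first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers  # 4

for layer_id, layer_type in enumerate(layer_types):
    is_sliding = layer_type == "sliding_attention"
    is_kv_shared = layer_id >= first_kv_shared_layer_idx
    if not is_kv_shared:
        kv_source = layer_id                       # layer owns its KV cache
    elif is_sliding:
        kv_source = first_kv_shared_layer_idx - 2  # last sliding layer before sharing starts
    else:
        kv_source = first_kv_shared_layer_idx - 1  # last global layer before sharing starts
    print(layer_id, layer_type, "-> reads KV of layer", kv_source)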
+class Gemma3nDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        layer_id: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
+        self.attention_type = config.layer_types[layer_id]
+        self.config = config
+
+        self.self_attn = Gemma3nAttention(
+            layer_id=layer_id,
+            config=config,
+            max_position_embeddings=config.max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        activation_sparsity = config.activation_sparsity_pattern[layer_id]
+        self.mlp = Gemma3nMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_activation=config.hidden_activation,
+            activation_sparsity=activation_sparsity,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Gemma3nRMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = Gemma3nRMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = Gemma3nRMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+        self.altup = Gemma3nAltUp(
+            config, quant_config, prefix=add_prefix("altup", prefix)
+        )
+        self.laurel = Gemma3nLaurelBlock(
+            config, quant_config, prefix=add_prefix("laurel", prefix)
+        )
+
+        self.per_layer_input_gate = ColumnParallelLinear(
+            self.hidden_size,
+            self.hidden_size_per_layer_input,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("per_layer_input_gate", prefix),
+        )
+        self.per_layer_projection = RowParallelLinear(
+            self.hidden_size_per_layer_input,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("per_layer_projection", prefix),
+        )
+        self.post_per_layer_input_norm = Gemma3nRMSNorm(
+            self.hidden_size, eps=config.rms_norm_eps
+        )
+        self.is_sliding = self.self_attn.is_sliding
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        per_layer_input: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ) -> torch.Tensor:
+        predictions = self.altup.predict(
+            hidden_states
+        )  # [num_altup_inputs, num_tokens, hidden_size]
+        active_prediction = predictions[self.config.altup_active_idx]
+
+        active_prediction_normed = self.input_layernorm(active_prediction)
+        laurel_output = self.laurel(
+            active_prediction_normed
+        )  # laurel_output: [num_tokens, hidden_size]
+        # active_prediction: [num_tokens, hidden_size]
+
+        attn = self.self_attn(
+            positions=positions,
+            hidden_states=active_prediction_normed,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+        attn = self.post_attention_layernorm(attn)  # [num_tokens, hidden_size]
+
+        attn_gated = active_prediction + attn  # [num_tokens, hidden_size]
+        attn_laurel = (attn_gated + laurel_output) / torch.sqrt(torch.tensor(2.0))
+
+        attn_norm = self.pre_feedforward_layernorm(
+            attn_laurel
+        )  # [num_tokens, hidden_size]
+        attn_ffw = self.mlp(attn_norm)  # [num_tokens, hidden_size]
+        attn_ffw_norm = self.post_feedforward_layernorm(
+            attn_ffw
+        )  # [num_tokens, hidden_size]
+        attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm  # [num_tokens, hidden_size]
+        corrected_predictions = self.altup.correct(
+            predictions, attn_ffw_laurel_gated
+        )  # prediction : [num_altup_inputs, num_tokens, hidden_size]
+        # attn_ffw_laurel_gated: [num_tokens, hidden_size]
+        first_prediction = corrected_predictions[self.config.altup_active_idx]
+
+        if self.config.altup_correct_scale:
+            first_prediction = self.altup.scale_corrected_output(first_prediction)
+
+        # per_layer_input_gate
+        first_prediction = first_prediction.to(self.per_layer_input_gate.weight.dtype)
+        first_prediction, _ = self.per_layer_input_gate(first_prediction)
+        first_prediction = F.gelu(first_prediction, approximate="tanh")
+        first_prediction = torch.multiply(first_prediction, per_layer_input)
+
+        # per_layer_projection
+        first_prediction, _ = self.per_layer_projection(first_prediction)
+        first_prediction = self.post_per_layer_input_norm(first_prediction)
+        corrected_predictions[1:] += first_prediction
+
+        return corrected_predictions
+
+
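Besides the usual attention + MLP path, every decoder layer mixes in a small per-layer embedding: the corrected active state is projected down to `hidden_size_per_layer_input`, gated with tanh-GELU, multiplied by that layer's per-layer input, projected back up, normalised, and added to the non-active AltUp states. A shape-only sketch of that tail end of `forward` (plain nn.Linear in place of the parallel linears, torch's RMSNorm in place of Gemma3nRMSNorm; sizes are illustrative and assume `altup_active_idx == 0`):

import torch
import torch.nn.functional as F
from torch import nn

hidden_size, hidden_size_per_layer_input, num_tokens = 256, 64, 3

per_layer_input_gate = nn.Linear(hidden_size, hidden_size_per_layer_input, bias=False)
per_layer_projection = nn.Linear(hidden_size_per_layer_input, hidden_size, bias=False)
post_norm = nn.RMSNorm(hidden_size)  # torch >= 2.4

first_prediction = torch.randn(num_tokens, hidden_size)  # corrected active AltUp state
per_layer_input = torch.randn(num_tokens, hidden_size_per_layer_input)

gated = F.gelu(per_layer_input_gate(first_prediction), approximate="tanh") * per_layer_input
update = post_norm(per_layer_projection(gated))          # [num_tokens, hidden_size]

corrected_predictions = torch.randn(4, num_tokens, hidden_size)  # [altup_num_inputs, tokens, hidden]
corrected_predictions[1:] += update                               # only the non-active states get the update
print(corrected_predictions.shape)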
+class Gemma3nTextModel(PreTrainedModel):
+    def __init__(
+        self,
+        config: Gemma3nTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+        self.vocab_size = config.vocab_size
+        self.padding_idx = config.pad_token_id
+
+        # Gemma3n downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
+        self.embed_tokens = Gemma3nTextScaledWordEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            embed_scale=self.config.hidden_size**0.5,
+        )
+
+        self.norm = Gemma3nRMSNorm(
+            config.hidden_size,
+            eps=config.rms_norm_eps,
+        )
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma3nDecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+
+        # Per-layer input embeddings
+        self.hidden_size = config.hidden_size
+        self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+        self.embed_tokens_per_layer = Gemma3nTextScaledWordEmbedding(
+            config.vocab_size_per_layer_input,
+            config.num_hidden_layers * config.hidden_size_per_layer_input,
+            self.padding_idx,
+            embed_scale=self.config.hidden_size_per_layer_input**0.5,
+        )
+
+        self.per_layer_model_projection = ColumnParallelLinear(
+            self.hidden_size,
+            config.num_hidden_layers * config.hidden_size_per_layer_input,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("per_layer_model_projection", prefix),
+        )
+
+        self.per_layer_projection_norm = Gemma3nRMSNorm(
+            dim=config.hidden_size_per_layer_input,
+            eps=config.rms_norm_eps,
+        )
+
+        self.altup_projections = make_layers(
+            self.config.altup_num_inputs - 1,
+            lambda idx, prefix: ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("altup_projections", prefix),
+        )
+
+        self.altup_unembed_projections = make_layers(
+            self.config.altup_num_inputs - 1,
+            lambda idx, prefix: ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("altup_unembed_projections", prefix),
+        )
+
+        self.register_buffer(
+            "per_layer_projection_scale",
+            torch.tensor(self.hidden_size**-0.5),
+            persistent=False,
+        )
+        self.register_buffer(
+            "per_layer_input_scale", torch.rsqrt(torch.tensor(2.0)), persistent=False
+        )
+
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
+        embeddings = self.embed_tokens_per_layer(input_ids)
+        return embeddings.reshape(
+            *input_ids.shape,
+            self.config.num_hidden_layers,
+            self.hidden_size_per_layer_input,
+        )
+
+    def project_per_layer_inputs(
+        self,
+        inputs_embeds: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        per_layer_projection, _ = self.per_layer_model_projection(inputs_embeds)
+        per_layer_projection *= self.per_layer_projection_scale.type(
+            inputs_embeds.dtype
+        )
+        per_layer_projection = per_layer_projection.reshape(
+            *inputs_embeds.shape[:-1],
+            self.config.num_hidden_layers,
+            self.hidden_size_per_layer_input,
+        )
+        per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+
+        if per_layer_inputs is None:
+            return per_layer_projection
+
+        if per_layer_projection.shape != per_layer_inputs.shape:
+            # per-layer inputs are sometimes padded with zeros, slice the relevant embeddings
+            per_layer_inputs = per_layer_inputs[..., : self.config.num_hidden_layers, :]
+
+        return (
+            per_layer_projection + per_layer_inputs
+        ) * self.per_layer_input_scale.type(inputs_embeds.dtype)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if (input_ids is None) ^ (input_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        if input_ids is not None:
+            input_embeds = self.embed_tokens(input_ids)
+            per_layer_inputs = self.get_per_layer_inputs(input_ids)
+
+        per_layer_inputs = self.project_per_layer_inputs(input_embeds, per_layer_inputs)
+
+        if positions.dim() == 1:
+            positions = positions.unsqueeze(0)
+
+        # Expand hidden_states to support per-layer inputs
+        target_magnitude = torch.mean(input_embeds**2, dim=-1, keepdim=True) ** 0.5
+        epsilon_tensor = torch.tensor(torch.finfo(input_embeds.dtype).min)
+
+        # embed positions
+        hidden_states_0 = input_embeds
+        temp_hidden_states = [hidden_states_0]
+
+        for i in range(1, self.config.altup_num_inputs):
+            altup_proj, _ = self.altup_projections[i - 1](hidden_states_0)
+            current_hidden_state = altup_proj.type(hidden_states_0.dtype)
+            new_magnitude = (
+                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
+            )
+            current_hidden_state = current_hidden_state * (
+                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
+            )
+            temp_hidden_states.append(current_hidden_state)
+
+        hidden_states = torch.stack(
+            temp_hidden_states, dim=0
+        )  # [num_altup_inputs, n_tokens, hidden_size]
+
+        for layer_idx, layer in enumerate(self.layers):
+            per_layer_input = per_layer_inputs[:, layer_idx, :]
+            hidden_states = layer(
+                positions=positions,
+                per_layer_input=per_layer_input,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+                **kwargs,
+            )
+
+        # Per-layer inputs to single output
+        target_magnitude = (
+            torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
+        )
+
+        temp_hidden_states = [hidden_states[0]]
+
+        for i in range(1, self.config.altup_num_inputs):
+            # altup_unembed_projections adapted from jax.numpy.einsum("btp,pd->btd", ...)
+            altup_unemb_proj, _ = self.altup_unembed_projections[i - 1](
+                hidden_states[i]
+            )
+            current_hidden_state = altup_unemb_proj.type(hidden_states_0.dtype)
+            new_magnitude = (
+                torch.mean(current_hidden_state**2, dim=-1, keepdim=True) ** 0.5
+            )
+            current_hidden_state = current_hidden_state * (
+                target_magnitude / torch.maximum(new_magnitude, epsilon_tensor)
+            )
+            temp_hidden_states.append(current_hidden_state)
+
+        hidden_states = torch.stack(temp_hidden_states)
+        hidden_states = torch.mean(hidden_states, dim=0)
+        hidden_states = self.norm(hidden_states)
+
+        return hidden_states
+
+
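Gemma3nTextModel keeps two embedding tables: the usual token embedding and a much smaller per-layer table whose output is reshaped into one `hidden_size_per_layer_input`-sized vector per decoder layer, then blended with a projection of the main embedding and rescaled by 1/sqrt(2). A shape-only sketch of `get_per_layer_inputs` / `project_per_layer_inputs` with toy dimensions (the RMSNorm applied to the projection is omitted for brevity):

import torch

num_tokens, hidden_size = 5, 256
num_hidden_layers, per_layer_dim = 8, 64

# get_per_layer_inputs: one small embedding per (token, layer) pair.
per_layer_embeds = torch.randn(num_tokens, num_hidden_layers * per_layer_dim)
per_layer_inputs = per_layer_embeds.reshape(num_tokens, num_hidden_layers, per_layer_dim)

# project_per_layer_inputs: project the main embedding to the same layout,
# scale by hidden_size**-0.5, add, and rescale by 1/sqrt(2).
inputs_embeds = torch.randn(num_tokens, hidden_size)
projection_weight = torch.randn(num_hidden_layers * per_layer_dim, hidden_size)
per_layer_projection = (inputs_embeds @ projection_weight.T) * hidden_size**-0.5
per_layer_projection = per_layer_projection.reshape(num_tokens, num_hidden_layers, per_layer_dim)

blended = (per_layer_projection + per_layer_inputs) * torch.rsqrt(torch.tensor(2.0))
print(blended.shape)  # torch.Size([5, 8, 64]) -> sliced as [:, layer_idx, :] in the layer loop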
+class Gemma3nForCausalLM(PreTrainedModel):
+    config_class = Gemma3nTextConfig
+
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+    config_class = Gemma3nTextConfig
+    base_model_prefix = "language_model"
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
+    }
+
+    packed_modules_mapping = {
+        ".qkv_proj": [
+            ".q_proj",
+            ".k_proj",
+            ".v_proj",
+        ],
+        ".gate_up_proj": [
+            ".gate_proj",
+            ".up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        ".qkv_proj",
+        ".o_proj",
+        ".gate_up_proj",
+        ".down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer
+    embedding_modules = {}
+    embedding_padding_modules = []
+    supports_lora = True
+
+    def __init__(
+        self,
+        config: Gemma3nTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Gemma3nTextModel(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> LogitsProcessor:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            per_layer_inputs,
+            **kwargs,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            name = name.replace("model.language_model.", "model.")
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    # Skip loading weights that are not in the model
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # lm_head is not used in vllm as it is tied with embed_token
+                if "lm_head.weight" in name:
+                    continue
+                # Skip loading extra bias for GPTQ models
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                if name not in params_dict:
+                    # Skip loading weights that are not in the model
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+EntryClass = Gemma3nForCausalLM
+AutoModel.register(Gemma3nTextConfig, Gemma3nForCausalLM, exist_ok=True)
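load_weights folds the checkpoint's separate q/k/v and gate/up tensors into the fused qkv_proj and gate_up_proj parameters via `stacked_params_mapping`. A standalone illustration of the name rewriting only (the actual weight copying is done by each parameter's `weight_loader` and is not reproduced here; the example checkpoint names are hypothetical):

stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

checkpoint_names = [
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "model.language_model.layers.0.mlp.up_proj.weight",
    "model.language_model.layers.0.mlp.down_proj.weight",  # not stacked, loaded as-is
]

for name in checkpoint_names:
    name = name.replace("model.language_model.", "model.")  # same prefix fix as load_weights
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            print(name.replace(shard_name, param_name), "<- shard", shard_id)
            break
    else:
        print(name, "<- loaded directly")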