sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3_causal.py (new file)
@@ -0,0 +1,684 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import copy
+from typing import Iterable, Optional, Set, Tuple
+
+import einops
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import (
+    ROPE_INIT_FUNCTIONS,
+    AutoModel,
+    Gemma3TextConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+)
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import Gemma3RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
+def extract_layer_index(prefix: str) -> int:
+    """Extract the layer index from a prefix string."""
+    parts = prefix.split(".")
+    for part in parts:
+        if part.startswith("layers."):
+            layer_str = part.split(".")[-1]
+            try:
+                return int(layer_str)
+            except ValueError:
+                continue
+    return -1
+
+
+class Gemma3MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_activation: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+        )
+        if hidden_activation != "gelu_pytorch_tanh":
+            raise ValueError(
+                "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
+                "function. Please set `hidden_activation` to "
+                "`gelu_pytorch_tanh`."
+            )
+        self.act_fn = GeluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Gemma3Attention(nn.Module):
+    def __init__(
+        self,
+        layer_id: int,
+        config: Gemma3TextConfig,
+        max_position_embeddings: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+
+        hidden_size = config.hidden_size
+
+        head_dim = getattr(
+            config, "head_dim", hidden_size // config.num_attention_heads
+        )
+        self.head_dim = head_dim
+
+        self.q_size = self.num_heads * self.head_dim
+
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = config.query_pre_attn_scalar**-0.5
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        # Determine if layer uses sliding window based on pattern
+        self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+
+        # Initialize the rotary embedding.
+        if self.is_sliding:
+            # Local attention. Override the values in config.json.
+            self.rope_theta = config.rope_local_base_freq
+            self.rope_scaling = {"rope_type": "default"}
+            # FIXME(mick): idk why vllm does this
+            # self.sliding_window = config.interleaved_sliding_window
+            self.sliding_window = config.sliding_window
+        else:
+            # Global attention. Use the values in config.json.
+            self.rope_theta = config.rope_theta
+            self.rope_scaling = config.rope_scaling
+            self.sliding_window = None
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            sliding_window_size=self.sliding_window,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        # Gemma3 adds normalization for q and k
+        self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+
+    def naive_attn_with_masks(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        out: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        q = q.view(-1, self.num_heads, self.head_dim)
+        # Expand the key and value to handle GQA.
+        num_queries_per_kv = self.num_heads // self.num_kv_heads
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
+        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
+        v = v.repeat_interleave(num_queries_per_kv, dim=-2)
+
+        if self.is_sliding:
+            attn_masks = kwargs["local_attn_masks"]
+        else:
+            attn_masks = kwargs["global_attn_masks"]
+
+        seq_lens = kwargs["seq_lens"]
+        start_idx = 0
+        for seq_len, attn_mask in zip(seq_lens, attn_masks):
+            end_idx = start_idx + seq_len
+            query = q[start_idx:end_idx].unsqueeze(0)
+            key = k[start_idx:end_idx].unsqueeze(0)
+            value = v[start_idx:end_idx].unsqueeze(0)
+
+            # Transpose.
+            query = query.transpose(1, 2)
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+
+            output = F.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask,
+                self.scaling,
+            )
+            output = output.transpose(1, 2).flatten(-2, -1)
+            out[start_idx:end_idx] = output
+            start_idx = end_idx
+        return out
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        # [s, h * head_dim]
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # [s, h, head_dim]
+        q = q.unflatten(-1, (self.num_heads, self.head_dim))
+        # -> [h, s, head_dim]
+        q = q.transpose(0, 1).unsqueeze(0)
+        q = self.q_norm(q)
+        k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+        # -> [h, s, head_dim]
+        k = k.transpose(0, 1).unsqueeze(0)
+        k = self.k_norm(k)
+
+        # q, k = self.rotary_emb(positions, q, k)
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+        # [b, h, s, head_dim] -> [b, s, h, head_dim]
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+
+        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Gemma3DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        layer_id: int,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Gemma3Attention(
+            layer_id=layer_id,
+            config=config,
+            max_position_embeddings=config.max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.hidden_size = config.hidden_size
+        self.mlp = Gemma3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_activation=config.hidden_activation,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = Gemma3RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = Gemma3RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = Gemma3RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = Gemma3RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.is_sliding = self.self_attn.is_sliding
+        self.layer_id = layer_id
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        position_embeddings_global: torch.Tensor,
+        position_embeddings_local: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs,
+    ) -> tuple[
+        torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # apply global RoPE to non-sliding layer only
+        if self.self_attn.is_sliding:
+            position_embeddings = position_embeddings_local
+        else:
+            position_embeddings = position_embeddings_global
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        return outputs
+
+
+class Gemma3RotaryEmbedding(nn.Module):
+    def __init__(self, config: Gemma3TextConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type")
+            )
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len
+            )
+            self.register_buffer(
+                "inv_freq", inv_freq, persistent=False
+            )  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
+
+        if (
+            seq_len < self.original_max_seq_len
+            and self.max_seq_len_cached > self.original_max_seq_len
+        ):  # reset
+            # This .to() is needed if the model has been moved to a device after being initialized (because
+            # the buffer is automatically moved, but not the original copy)
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+
+        # Core RoPE block
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        )
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+        device_type = x.device.type
+        device_type = (
+            device_type
+            if isinstance(device_type, str) and device_type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (
+                inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
+            ).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Gemma3TextScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int,
+        embed_scale: Optional[float] = 1.0,
+    ):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
+
+class Gemma3TextModel(PreTrainedModel):
+    def __init__(
+        self,
+        config: Gemma3TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
+        self.embed_tokens = Gemma3TextScaledWordEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            self.padding_idx,
+            embed_scale=self.config.hidden_size**0.5,
+        )
+
+        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = Gemma3RotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+
+        # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
+        config = copy.deepcopy(config)
+        config.rope_theta = config.rope_local_base_freq
+        config.rope_scaling = {"rope_type": "default"}
+        self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
+
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma3DecoderLayer(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
+        )
+        self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        if positions.dim() == 1:
+            positions = einops.rearrange(positions, "s -> 1 s")
+
+        position_embeddings_global = self.rotary_emb(hidden_states, positions)
+        position_embeddings_local = self.rotary_emb_local(hidden_states, positions)
+        for layer in self.layers:
+            layer_outputs = layer(
+                positions=positions,
+                position_embeddings_global=position_embeddings_global,
+                position_embeddings_local=position_embeddings_local,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+                **kwargs,
+            )
+            hidden_states = layer_outputs[0]
+
+        hidden_states = self.norm(hidden_states)
+
+        return hidden_states
+
+
+class Gemma3ForCausalLM(PreTrainedModel):
+    config_class = Gemma3TextConfig
+
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+    config_class = Gemma3TextConfig
+    base_model_prefix = "language_model"
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
+    supports_lora = True
+
+    def __init__(
+        self,
+        config: Gemma3TextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Gemma3TextModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def dtype(self) -> torch.dtype:
+        return next(self.parameters()).dtype
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs,
+    ) -> LogitsProcessor:
+
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, input_embeds, **kwargs
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                # if param_name in name:
+                #     print(f"{param_name} is already in {name}")
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        # unloaded_params = params_dict.keys() - loaded_params
+        # if unloaded_params:
+        #     logger.warning(
+        #         "Some weights are not initialized from checkpoints: %s", unloaded_params
+        #     )
+        return loaded_params
+
+
+EntryClass = Gemma3ForCausalLM
+AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
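
Note on the new Gemma3 text model above: the layer type is chosen by `is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)`, so every layer whose 1-based index is a multiple of `sliding_window_pattern` runs global attention with the config's `rope_theta`/`rope_scaling`, while all other layers run local sliding-window attention with `rope_local_base_freq` and default RoPE scaling (see `Gemma3Attention.__init__` and the local/global position-embedding selection in `Gemma3DecoderLayer.forward`). A minimal standalone sketch of that selection, using hypothetical config values (`sliding_window_pattern = 6`, 12 layers) rather than numbers from any real Gemma3 checkpoint:

# Sketch only: mirrors the is_sliding / RoPE selection logic in the diff above.
# The two config values below are hypothetical, not taken from a real checkpoint.
sliding_window_pattern = 6
num_hidden_layers = 12

for layer_id in range(num_hidden_layers):
    is_sliding = bool((layer_id + 1) % sliding_window_pattern)
    if is_sliding:
        kind = "local sliding-window attention (rope_local_base_freq, default RoPE scaling)"
    else:
        kind = "global attention (rope_theta / rope_scaling from config.json)"
    print(f"layer {layer_id:2d}: {kind}")

# With these values, layers 5 and 11 are global; the other ten layers are local.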