sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3_causal.py (new file)
@@ -0,0 +1,694 @@
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ import copy
+ from typing import Iterable, Optional, Set, Tuple
+
+ import einops
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import (
+     ROPE_INIT_FUNCTIONS,
+     AutoModel,
+     Gemma3TextConfig,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.activation import GeluAndMul
+ from sglang.srt.layers.layernorm import Gemma3RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     maybe_remap_kv_scale_name,
+ )
+ from sglang.srt.utils import add_prefix, make_layers
+
+
+ # Aligned with HF's implementation, using sliding window inclusive with the last token
+ # SGLang assumes exclusive
+ def get_attention_sliding_window_size(config):
+     return config.sliding_window - 1
+
+
+ # Adapted from:
+ # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
+ def extract_layer_index(prefix: str) -> int:
+     """Extract the layer index from a prefix string."""
+     parts = prefix.split(".")
+     for part in parts:
+         if part.startswith("layers."):
+             layer_str = part.split(".")[-1]
+             try:
+                 return int(layer_str)
+             except ValueError:
+                 continue
+     return -1
+
+
+ class Gemma3MLP(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_activation: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size,
+             [intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("gate_up_proj", prefix),
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             prefix=add_prefix("down_proj", prefix),
+         )
+         if hidden_activation != "gelu_pytorch_tanh":
+             raise ValueError(
+                 "Gemma3 uses `gelu_pytorch_tanh` as the hidden activation "
+                 "function. Please set `hidden_activation` to "
+                 "`gelu_pytorch_tanh`."
+             )
+         self.act_fn = GeluAndMul()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class Gemma3Attention(nn.Module):
+     def __init__(
+         self,
+         layer_id: int,
+         config: Gemma3TextConfig,
+         max_position_embeddings: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.layer_id = layer_id
+         self.config = config
+         tp_size = get_tensor_model_parallel_world_size()
+
+         self.total_num_heads = config.num_attention_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = config.num_key_value_heads
+
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+
+         hidden_size = config.hidden_size
+
+         head_dim = getattr(
+             config, "head_dim", hidden_size // config.num_attention_heads
+         )
+         self.head_dim = head_dim
+
+         self.q_size = self.num_heads * self.head_dim
+
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = config.query_pre_attn_scalar**-0.5
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=config.attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=config.attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         # Determine if layer uses sliding window based on pattern
+         self.is_sliding = bool((layer_id + 1) % config.sliding_window_pattern)
+
+         # Initialize the rotary embedding.
+         if self.is_sliding:
+             # Local attention. Override the values in config.json.
+             self.rope_theta = config.rope_local_base_freq
+             self.rope_scaling = {"rope_type": "default"}
+             # FIXME(mick): idk why vllm does this
+             # self.sliding_window = config.interleaved_sliding_window
+             self.sliding_window = get_attention_sliding_window_size(config)
+         else:
+             # Global attention. Use the values in config.json.
+             self.rope_theta = config.rope_theta
+             self.rope_scaling = config.rope_scaling
+             self.sliding_window = None
+
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+             # Module must also define `get_attention_sliding_window_size` to correctly initialize
+             # attention backend in `ForwardBatch`.
+             sliding_window_size=self.sliding_window,
+             prefix=add_prefix("attn", prefix),
+         )
+
+         # Gemma3 adds normalization for q and k
+         self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps)
+
+     def naive_attn_with_masks(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         out: torch.Tensor,
+         **kwargs,
+     ) -> torch.Tensor:
+         q = q.view(-1, self.num_heads, self.head_dim)
+         # Expand the key and value to handle GQA.
+         num_queries_per_kv = self.num_heads // self.num_kv_heads
+         k = k.view(-1, self.num_kv_heads, self.head_dim)
+         k = k.repeat_interleave(num_queries_per_kv, dim=-2)
+         v = v.view(-1, self.num_kv_heads, self.head_dim)
+         v = v.repeat_interleave(num_queries_per_kv, dim=-2)
+
+         if self.is_sliding:
+             attn_masks = kwargs["local_attn_masks"]
+         else:
+             attn_masks = kwargs["global_attn_masks"]
+
+         seq_lens = kwargs["seq_lens"]
+         start_idx = 0
+         for seq_len, attn_mask in zip(seq_lens, attn_masks):
+             end_idx = start_idx + seq_len
+             query = q[start_idx:end_idx].unsqueeze(0)
+             key = k[start_idx:end_idx].unsqueeze(0)
+             value = v[start_idx:end_idx].unsqueeze(0)
+
+             # Transpose.
+             query = query.transpose(1, 2)
+             key = key.transpose(1, 2)
+             value = value.transpose(1, 2)
+
+             output = F.scaled_dot_product_attention(
+                 query,
+                 key,
+                 value,
+                 attn_mask,
+                 self.scaling,
+             )
+             output = output.transpose(1, 2).flatten(-2, -1)
+             out[start_idx:end_idx] = output
+             start_idx = end_idx
+         return out
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         forward_batch: ForwardBatch,
+         **kwargs,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         # [s, h * head_dim]
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+         # [s, h, head_dim]
+         q = q.unflatten(-1, (self.num_heads, self.head_dim))
+         # -> [h, s, head_dim]
+         q = q.transpose(0, 1).unsqueeze(0)
+         q = self.q_norm(q)
+         k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+         # -> [h, s, head_dim]
+         k = k.transpose(0, 1).unsqueeze(0)
+         k = self.k_norm(k)
+
+         # q, k = self.rotary_emb(positions, q, k)
+         cos, sin = position_embeddings
+         q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+         # [b, h, s, head_dim] -> [b, s, h, head_dim]
+         q = q.permute(0, 2, 1, 3)
+         k = k.permute(0, 2, 1, 3)
+
+         attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Gemma3DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         layer_id: int,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = Gemma3Attention(
+             layer_id=layer_id,
+             config=config,
+             max_position_embeddings=config.max_position_embeddings,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         self.hidden_size = config.hidden_size
+         self.mlp = Gemma3MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_activation=config.hidden_activation,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+         self.input_layernorm = Gemma3RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_attention_layernorm = Gemma3RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.pre_feedforward_layernorm = Gemma3RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.post_feedforward_layernorm = Gemma3RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.is_sliding = self.self_attn.is_sliding
+         self.layer_id = layer_id
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         position_embeddings_global: torch.Tensor,
+         position_embeddings_local: torch.Tensor,
+         forward_batch: ForwardBatch,
+         **kwargs,
+     ) -> tuple[
+         torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
+     ]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # apply global RoPE to non-sliding layer only
+         if self.self_attn.is_sliding:
+             position_embeddings = position_embeddings_local
+         else:
+             position_embeddings = position_embeddings_global
+
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             position_embeddings=position_embeddings,
+             forward_batch=forward_batch,
+             **kwargs,
+         )
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.pre_feedforward_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = self.post_feedforward_layernorm(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         return outputs
+
+
+ class Gemma3RotaryEmbedding(nn.Module):
+     def __init__(self, config: Gemma3TextConfig, device=None):
+         super().__init__()
+         # BC: "rope_type" was originally "type"
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             self.rope_type = config.rope_scaling.get(
+                 "rope_type", config.rope_scaling.get("type")
+             )
+         else:
+             self.rope_type = "default"
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     def _dynamic_frequency_update(self, position_ids, device):
+         """
+         dynamic RoPE layers should recompute `inv_freq` in the following situations:
+         1 - growing beyond the cached sequence length (allow scaling)
+         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+         """
+         seq_len = torch.max(position_ids) + 1
+         if seq_len > self.max_seq_len_cached:  # growth
+             inv_freq, self.attention_scaling = self.rope_init_fn(
+                 self.config, device, seq_len=seq_len
+             )
+             self.register_buffer(
+                 "inv_freq", inv_freq, persistent=False
+             )  # TODO joao: may break with compilation
+             self.max_seq_len_cached = seq_len
+
+         if (
+             seq_len < self.original_max_seq_len
+             and self.max_seq_len_cached > self.original_max_seq_len
+         ):  # reset
+             # This .to() is needed if the model has been moved to a device after being initialized (because
+             # the buffer is automatically moved, but not the original copy)
+             self.original_inv_freq = self.original_inv_freq.to(device)
+             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+             self.max_seq_len_cached = self.original_max_seq_len
+
+     @torch.no_grad()
+     def forward(self, x, position_ids):
+         if "dynamic" in self.rope_type:
+             self._dynamic_frequency_update(position_ids, device=x.device)
+
+         # Core RoPE block
+         inv_freq_expanded = (
+             self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+         )
+         position_ids_expanded = position_ids[:, None, :].float()
+         # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+         device_type = x.device.type
+         device_type = (
+             device_type
+             if isinstance(device_type, str) and device_type != "mps"
+             else "cpu"
+         )
+         with torch.autocast(device_type=device_type, enabled=False):
+             freqs = (
+                 inv_freq_expanded.float().to(x.device) @ position_ids_expanded.float()
+             ).transpose(1, 2)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos()
+             sin = emb.sin()
+
+         # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+         cos = cos * self.attention_scaling
+         sin = sin * self.attention_scaling
+
+         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+ class Gemma3TextScaledWordEmbedding(nn.Embedding):
+     """
+     This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+     """
+
+     def __init__(
+         self,
+         num_embeddings: int,
+         embedding_dim: int,
+         padding_idx: int,
+         embed_scale: Optional[float] = 1.0,
+     ):
+         super().__init__(num_embeddings, embedding_dim, padding_idx)
+         self.embed_scale = embed_scale
+
+     def forward(self, input_ids: torch.Tensor):
+         return super().forward(input_ids) * self.embed_scale
+
+
+ class Gemma3TextModel(PreTrainedModel):
+     def __init__(
+         self,
+         config: Gemma3TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(config=config)
+         self.config = config
+         self.quant_config = quant_config
+
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         # Gemma3 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5. See https://github.com/huggingface/transformers/pull/29402
+         self.embed_tokens = Gemma3TextScaledWordEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             self.padding_idx,
+             embed_scale=self.config.hidden_size**0.5,
+         )
+
+         self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.rotary_emb = Gemma3RotaryEmbedding(config=config)
+         self.gradient_checkpointing = False
+
+         # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
+         config = copy.deepcopy(config)
+         config.rope_theta = config.rope_local_base_freq
+         config.rope_scaling = {"rope_type": "default"}
+         self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)
+
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: Gemma3DecoderLayer(
+                 layer_id=idx,
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=prefix,
+             ),
+             prefix=add_prefix("layers", prefix),
+         )
+         self.norm = Gemma3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+
+         if positions.dim() == 1:
+             positions = einops.rearrange(positions, "s -> 1 s")
+
+         position_embeddings_global = self.rotary_emb(hidden_states, positions)
+         position_embeddings_local = self.rotary_emb_local(hidden_states, positions)
+         for layer in self.layers:
+             layer_outputs = layer(
+                 positions=positions,
+                 position_embeddings_global=position_embeddings_global,
+                 position_embeddings_local=position_embeddings_local,
+                 hidden_states=hidden_states,
+                 forward_batch=forward_batch,
+                 **kwargs,
+             )
+             hidden_states = layer_outputs[0]
+
+         hidden_states = self.norm(hidden_states)
+
+         return hidden_states
+
+
+ class Gemma3ForCausalLM(PreTrainedModel):
+     config_class = Gemma3TextConfig
+
+     _tied_weights_keys = ["lm_head.weight"]
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+     config_class = Gemma3TextConfig
+     base_model_prefix = "language_model"
+
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
+     packed_modules_mapping = {
+         "qkv_proj": [
+             "q_proj",
+             "k_proj",
+             "v_proj",
+         ],
+         "gate_up_proj": [
+             "gate_proj",
+             "up_proj",
+         ],
+     }
+
+     # LoRA specific attributes
+     supported_lora_modules = [
+         "qkv_proj",
+         "o_proj",
+         "gate_up_proj",
+         "down_proj",
+     ]
+     # Gemma does not apply LoRA to the embedding layer.
+     embedding_modules = {}
+     embedding_padding_modules = []
+     supports_lora = True
+
+     def __init__(
+         self,
+         config: Gemma3TextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(config=config)
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Gemma3TextModel(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+         if self.config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("lm_head", prefix),
+             )
+         self.post_init()
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.model.embed_tokens
+
+     def get_attention_sliding_window_size(self):
+         return get_attention_sliding_window_size(self.config)
+
+     def dtype(self) -> torch.dtype:
+         return next(self.parameters()).dtype
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         **kwargs,
+     ) -> LogitsProcessor:
+         hidden_states = self.model(
+             input_ids, positions, forward_batch, input_embeds, **kwargs
+         )
+
+         return self.logits_processor(
+             input_ids, hidden_states, self.model.embed_tokens, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         loaded_params: Set[str] = set()
+         for name, loaded_weight in weights:
+             for param_name, shard_name, shard_id in stacked_params_mapping:
+                 # if param_name in name:
+                 #     print(f"{param_name} is already in {name}")
+                 if shard_name not in name:
+                     continue
+                 name = name.replace(shard_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # lm_head is not used in vllm as it is tied with embed_token.
+                 # To prevent errors, skip loading lm_head.weight.
+                 if "lm_head.weight" in name:
+                     continue
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Remapping the name of FP8 kv-scale.
+                 name = maybe_remap_kv_scale_name(name, params_dict)
+                 if name is None:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+             loaded_params.add(name)
+         # unloaded_params = params_dict.keys() - loaded_params
+         # if unloaded_params:
+         #     logger.warning(
+         #         "Some weights are not initialized from checkpoints: %s", unloaded_params
+         #     )
+         return loaded_params
+
+
+ EntryClass = Gemma3ForCausalLM
+ AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
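
For reference, below is a minimal sketch of how the newly registered text-only Gemma 3 model (`EntryClass = Gemma3ForCausalLM` above) could be exercised through sglang's offline `Engine` entrypoint, which also changed in this release (`sglang/srt/entrypoints/engine.py`). The checkpoint id, prompt, and sampling parameters are illustrative assumptions, not part of the diff.

# Illustrative sketch only; assumes a locally available Gemma 3 text checkpoint
# and the offline Engine API exposed as sglang.Engine in this release.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="google/gemma-3-1b-it")  # hypothetical checkpoint choice
    output = llm.generate(
        "Explain sliding-window attention in one sentence.",
        {"temperature": 0.0, "max_new_tokens": 64},
    )
    print(output["text"])
    llm.shutdown()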