sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3_mm.py (new file)
@@ -0,0 +1,462 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
+
+import logging
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
+
+import torch
+from torch import nn
+from transformers import (
+    AutoModel,
+    BatchFeature,
+    Gemma3Config,
+    Gemma3Processor,
+    PreTrainedModel,
+)
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.layernorm import Gemma3RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Gemma3ImagePixelInputs(TypedDict):
+    pixel_values: torch.Tensor
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+
+
+class Gemma3MultiModalProjector(nn.Module):
+    """Projector for Gemma3 multimodal."""
+
+    def __init__(self, config: Gemma3Config):
+        super().__init__()
+
+        self.mm_input_projection_weight = nn.Parameter(
+            torch.zeros(
+                config.vision_config.hidden_size, config.text_config.hidden_size
+            )
+        )
+
+        self.mm_soft_emb_norm = Gemma3RMSNorm(
+            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+        )
+
+        self.patches_per_image = int(
+            config.vision_config.image_size // config.vision_config.patch_size
+        )
+        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = nn.AvgPool2d(
+            kernel_size=self.kernel_size, stride=self.kernel_size
+        )
+
+    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_length, hidden_size = vision_outputs.shape
+
+        # Reshape for pooling
+        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+            batch_size, hidden_size, self.patches_per_image, self.patches_per_image
+        )
+        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+
+        # Apply pooling
+        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
+        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
+
+        # Apply normalization
+        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+        # Project to text embedding space
+        projected_vision_outputs = torch.matmul(
+            normed_vision_outputs, self.mm_input_projection_weight
+        )
+
+        return projected_vision_outputs.type_as(vision_outputs)
+
+
+class Gemma3ForConditionalGeneration(PreTrainedModel):
+    config_class = Gemma3Config
+    """Gemma3 multimodal model for conditional generation."""
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
+    supports_lora = True
+
+    def __init__(
+        self,
+        config: Gemma3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+        # Vision components
+        # TODO: replace with vision attention
+        # self.vision_tower = SiglipVisionModel(
+        #     config.vision_config,
+        #     quant_config,
+        #     prefix=add_prefix("vision_tower", prefix),
+        # )
+        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        self.multi_modal_projector = Gemma3MultiModalProjector(config)
+        self.vocab_size = config.text_config.vocab_size
+
+        # Text model
+        self.language_model = Gemma3ForCausalLM(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        if self.language_model.logits_processor.logit_scale:
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.language_model.logits_processor.logit_scale *= logit_scale
+        self.post_init()
+
+    def pad_input_ids(
+        self, input_ids: List[int], image_inputs: MultimodalInputs
+    ) -> List[int]:
+        """Pad input IDs with image tokens."""
+        # Get special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        ids = pattern.pad_input_tokens(input_ids, image_inputs)
+        return ids
+
+    def prepare_attn_masks(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        mask_dtype: torch.dtype,
+        **kwargs,
+    ) -> Dict:
+        """Prepare attention masks for multimodal inputs."""
+        kwargs["has_images"] = True
+
+        # Distinguish sequences by position id 0
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+
+        for i in range(num_seqs):
+            start_idx = start_indices[i].item()
+            if i < num_seqs - 1:
+                end_idx = start_indices[i + 1].item()
+            else:
+                end_idx = len(input_ids)
+            seq_lens.append(end_idx - start_idx)
+
+        kwargs["seq_lens"] = seq_lens
+
+        # Create attention masks
+        global_attn_masks = []
+        local_attn_masks = []
+        sliding_window = self.config.text_config.interleaved_sliding_window
+
+        start_idx = 0
+        for seq_len in seq_lens:
+            end_idx = start_idx + seq_len
+            input_token_ids = input_ids[start_idx:end_idx]
+            start_idx = end_idx
+
+            # Create global causal mask
+            global_attn_mask = torch.empty(
+                1,
+                1,
+                seq_len,
+                seq_len,
+                dtype=mask_dtype,
+                device=input_ids.device,
+            )
+            global_attn_mask.fill_(float("-inf"))
+            global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+            # Consider bidirectional attention between image tokens
+            img_mask = torch.zeros_like(global_attn_mask)
+            img_pos = input_token_ids == self.config.image_token_index
+            img_mask[:, :, :, img_pos] += 1
+            img_mask[:, :, img_pos, :] += 1
+            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+            global_attn_masks.append(global_attn_mask)
+
+            # Create local causal mask with sliding window
+            local_attn_mask = torch.ones_like(global_attn_mask)
+            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+            local_attn_mask = torch.where(
+                local_attn_mask == 0, global_attn_mask, float("-inf")
+            )
+            local_attn_masks.append(local_attn_mask)
+
+        kwargs["global_attn_masks"] = global_attn_masks
+        kwargs["local_attn_masks"] = local_attn_masks
+        return kwargs
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.get_input_embeddings()
+
+    def get_image_feature(self, image_input: MultimodalInputs):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        """
+        pixel_values = image_input.pixel_values
+        pixel_values = pixel_values.to("cuda")
+        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
+
+        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs)
+        return image_features
+
+    def embed_mm_inputs(
+        self,
+        input_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        image_input: MultimodalInputs,
+    ) -> torch.Tensor:
+        if input_ids is None:
+            raise ValueError("Unimplemented")
+        # boolean-masking image tokens
+        special_image_mask = torch.isin(
+            input_ids,
+            torch.tensor(image_input.pad_values, device=input_ids.device),
+        ).unsqueeze(-1)
+        num_image_tokens_in_input_ids = special_image_mask.sum()
+
+        inputs_embeds = None
+        if num_image_tokens_in_input_ids == 0:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+            return inputs_embeds
+        else:
+            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
+            image_features = self.get_image_feature(image_input)
+
+            # print(f"image tokens from image embeddings: {image_features.numel()}")
+            num_image_tokens_in_embedding = (
+                image_features.shape[0] * image_features.shape[1]
+            )
+
+            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
+                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
+                image_features = image_features[:num_image, :]
+                logger.warning(
+                    f"Number of images does not match number of special image tokens in the input text. "
+                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
+                    "tokens from image embeddings."
+                )
+
+            # Important: clamp after extracting original image boundaries
+            input_ids.clamp_(min=0, max=self.vocab_size - 1)
+
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                inputs_embeds.device
+            )
+
+            image_features = image_features.to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_features
+            )
+
+            return inputs_embeds
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs: object,
+    ) -> LogitsProcessor:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
+        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
+
+        >>> prompt = "answer en Where is the cow standing?"
+        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "answer en Where is the cow standing?\nbeach"
+        ```"""
+
+        # Important: position_ids in Gemma3 are 1-indexed
+        # (this really did cost me some time)
+        positions += 1
+
+        # Replace the image token id with PAD if it is OOV, to avoid index errors
+        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_index
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=llm_input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )
+
+        outputs = self.language_model(
+            input_ids=None,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        return outputs
+
+    def tie_weights(self):
+        return self.language_model.tie_weights()
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model."""
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "language_model" in name:
+                # Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
+                causal_loaded_params = Gemma3ForCausalLM.load_weights(
+                    self, [(name, loaded_weight)]
+                )
+                loaded_params.update(causal_loaded_params)
+                continue
+            else:
+                # Skip lm_head.weight as it's tied with embed_tokens
+                if "lm_head.weight" in name:
+                    continue
+
+                # Skip loading extra bias for GPTQ models
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                # Remapping the name of FP8 kv-scale
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            pass
+            # raise RuntimeError(
+            #     f"Some weights are not initialized from checkpoints: {unloaded_params}")
+        return loaded_params
+
+
+EntryClass = Gemma3ForConditionalGeneration
+
+AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)