sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3_mm.py (new file)
@@ -0,0 +1,462 @@
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
+
+import logging
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
+
+import torch
+from torch import nn
+from transformers import (
+    AutoModel,
+    BatchFeature,
+    Gemma3Config,
+    Gemma3Processor,
+    PreTrainedModel,
+)
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.layernorm import Gemma3RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Gemma3ImagePixelInputs(TypedDict):
+    pixel_values: torch.Tensor
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+
+
+class Gemma3MultiModalProjector(nn.Module):
+    """Projector for Gemma3 multimodal."""
+
+    def __init__(self, config: Gemma3Config):
+        super().__init__()
+
+        self.mm_input_projection_weight = nn.Parameter(
+            torch.zeros(
+                config.vision_config.hidden_size, config.text_config.hidden_size
+            )
+        )
+
+        self.mm_soft_emb_norm = Gemma3RMSNorm(
+            config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps
+        )
+
+        self.patches_per_image = int(
+            config.vision_config.image_size // config.vision_config.patch_size
+        )
+        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
+        self.kernel_size = self.patches_per_image // self.tokens_per_side
+        self.avg_pool = nn.AvgPool2d(
+            kernel_size=self.kernel_size, stride=self.kernel_size
+        )
+
+    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_length, hidden_size = vision_outputs.shape
+
+        # Reshape for pooling
+        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
+        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
+            batch_size, hidden_size, self.patches_per_image, self.patches_per_image
+        )
+        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
+
+        # Apply pooling
+        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
+        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
+        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
+
+        # Apply normalization
+        normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
+
+        # Project to text embedding space
+        projected_vision_outputs = torch.matmul(
+            normed_vision_outputs, self.mm_input_projection_weight
+        )
+
+        return projected_vision_outputs.type_as(vision_outputs)
+
+
+class Gemma3ForConditionalGeneration(PreTrainedModel):
+    config_class = Gemma3Config
+    """Gemma3 multimodal model for conditional generation."""
+
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
+    supports_lora = True
+
+    def __init__(
+        self,
+        config: Gemma3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+        # Vision components
+        # TODO: replace with vision attention
+        # self.vision_tower = SiglipVisionModel(
+        #     config.vision_config,
+        #     quant_config,
+        #     prefix=add_prefix("vision_tower", prefix),
+        # )
+        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+        self.multi_modal_projector = Gemma3MultiModalProjector(config)
+        self.vocab_size = config.text_config.vocab_size
+
+        # Text model
+        self.language_model = Gemma3ForCausalLM(
+            config.text_config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        if self.language_model.logits_processor.logit_scale:
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.language_model.logits_processor.logit_scale *= logit_scale
+        self.post_init()
+
+    def pad_input_ids(
+        self, input_ids: List[int], image_inputs: MultimodalInputs
+    ) -> List[int]:
+        """Pad input IDs with image tokens."""
+        # Get special token IDs
+        im_start_id: int = image_inputs.im_start_id
+        im_end_id: int = image_inputs.im_end_id
+
+        media_token_pairs = [(im_start_id, im_end_id)]
+        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+        ids = pattern.pad_input_tokens(input_ids, image_inputs)
+        return ids
+
+    def prepare_attn_masks(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        mask_dtype: torch.dtype,
+        **kwargs,
+    ) -> Dict:
+        """Prepare attention masks for multimodal inputs."""
+        kwargs["has_images"] = True
+
+        # Distinguish sequences by position id 0
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+
+        for i in range(num_seqs):
+            start_idx = start_indices[i].item()
+            if i < num_seqs - 1:
+                end_idx = start_indices[i + 1].item()
+            else:
+                end_idx = len(input_ids)
+            seq_lens.append(end_idx - start_idx)
+
+        kwargs["seq_lens"] = seq_lens
+
+        # Create attention masks
+        global_attn_masks = []
+        local_attn_masks = []
+        sliding_window = self.config.text_config.interleaved_sliding_window
+
+        start_idx = 0
+        for seq_len in seq_lens:
+            end_idx = start_idx + seq_len
+            input_token_ids = input_ids[start_idx:end_idx]
+            start_idx = end_idx
+
+            # Create global causal mask
+            global_attn_mask = torch.empty(
+                1,
+                1,
+                seq_len,
+                seq_len,
+                dtype=mask_dtype,
+                device=input_ids.device,
+            )
+            global_attn_mask.fill_(float("-inf"))
+            global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+            # Consider bidirectional attention between image tokens
+            img_mask = torch.zeros_like(global_attn_mask)
+            img_pos = input_token_ids == self.config.image_token_index
+            img_mask[:, :, :, img_pos] += 1
+            img_mask[:, :, img_pos, :] += 1
+            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+            global_attn_masks.append(global_attn_mask)

+            # Create local causal mask with sliding window
+            local_attn_mask = torch.ones_like(global_attn_mask)
+            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+            local_attn_mask = torch.where(
+                local_attn_mask == 0, global_attn_mask, float("-inf")
+            )
+            local_attn_masks.append(local_attn_mask)
+
+        kwargs["global_attn_masks"] = global_attn_masks
+        kwargs["local_attn_masks"] = local_attn_masks
+        return kwargs
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.get_input_embeddings()
+
+    def get_image_feature(self, image_input: MultimodalInputs):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Args:
+            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+        """
+        pixel_values = image_input.pixel_values
+        pixel_values = pixel_values.to("cuda")
+        pixel_values = pixel_values.to(dtype=self.language_model.dtype())
+
+        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs)
+        return image_features
+
+    def embed_mm_inputs(
+        self,
+        input_ids: torch.Tensor,
+        forward_batch: ForwardBatch,
+        image_input: MultimodalInputs,
+    ) -> torch.Tensor:
+        if input_ids is None:
+            raise ValueError("Unimplemented")
+        # boolean-masking image tokens
+        special_image_mask = torch.isin(
+            input_ids,
+            torch.tensor(image_input.pad_values, device=input_ids.device),
+        ).unsqueeze(-1)
+        num_image_tokens_in_input_ids = special_image_mask.sum()
+
+        inputs_embeds = None
+        if num_image_tokens_in_input_ids == 0:
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+            return inputs_embeds
+        else:
+            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
+            image_features = self.get_image_feature(image_input.pixel_values)
+
+            # print(f"image tokens from image embeddings: {image_features.numel()}")
+            num_image_tokens_in_embedding = (
+                image_features.shape[0] * image_features.shape[1]
+            )
+
+            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
+                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
+                image_features = image_features[:num_image, :]
+                logger.warning(
+                    f"Number of images does not match number of special image tokens in the input text. "
+                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
+                    "tokens from image embeddings."
+                )
+
+            # Important: clamp after extracting original image boundaries
+            input_ids.clamp_(min=0, max=self.vocab_size - 1)
+
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+                inputs_embeds.device
+            )
+
+            image_features = image_features.to(
+                inputs_embeds.device, inputs_embeds.dtype
+            )
+            inputs_embeds = inputs_embeds.masked_scatter(
+                special_image_mask, image_features
+            )
+
+            return inputs_embeds
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs: object,
+    ) -> LogitsProcessor:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
+
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf")
+        >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf")
+
+        >>> prompt = "answer en Where is the cow standing?"
+        >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(**inputs, max_length=30)
+        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "answer en Where is the cow standing?\nbeach"
+        ```"""
+
+        # Important: position_ids in Gemma3 are 1-indexed
+        # This really does cost me sometime
+        positions += 1
+
+        # Replace image id with PAD if the image token if OOV, to avoid index-errors
+        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
+            special_image_mask = input_ids == self.config.image_token_index
+            llm_input_ids = input_ids.clone()
+            llm_input_ids[special_image_mask] = 0
+        else:
+            llm_input_ids = input_ids
+
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=llm_input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )
+
+        outputs = self.language_model(
+            input_ids=None,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        return outputs
+
+    def tie_weights(self):
+        return self.language_model.tie_weights()
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model."""
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "language_model" in name:
+                # Gemma3ForCausalLM.load_weights(self, [(name.replace("language_model.", ""), loaded_weight)])
+                causal_loaded_params = Gemma3ForCausalLM.load_weights(
+                    self, [(name, loaded_weight)]
+                )
+                loaded_params.update(causal_loaded_params)
+                continue
+            else:
+                # Skip lm_head.weight as it's tied with embed_tokens
+                if "lm_head.weight" in name:
+                    continue
+
+                # Skip loading extra bias for GPTQ models
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                # Remapping the name of FP8 kv-scale
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+                loaded_params.add(name)
+        unloaded_params = params_dict.keys() - loaded_params
+        if unloaded_params:
+            pass
+            # raise RuntimeError(
+            #     f"Some weights are not initialized from checkpoints: {unloaded_params}")
+        return loaded_params
+
+
+EntryClass = Gemma3ForConditionalGeneration
+
+AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)
sglang/srt/models/llama.py
@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", self.hidden_size // self.total_num_heads
         )
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
@@ -285,6 +287,7 @@ class LlamaModel(nn.Module):
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers_to_capture = []
 
     def forward(
         self,
@@ -292,13 +295,16 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
         else:
            hidden_states = input_embeds
         residual = None
+        aux_hidden_states = []
         for i in range(len(self.layers)):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
@@ -307,7 +313,11 @@
                residual,
            )
         hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
@@ -335,7 +345,6 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
-
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
         ".gate_proj.",
@@ -391,6 +400,8 @@ class LlamaForCausalLM(nn.Module):
            (".gate_up_proj", ".up_proj", 1),
         ]
 
+        self.capture_aux_hidden_states = False
+
     @torch.no_grad()
     def forward(
         self,
@@ -400,10 +411,19 @@
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
     ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = self.model(
+                input_ids, positions, forward_batch, input_embeds
+            )
+        else:
+            hidden_states = self.model(
+                input_ids, positions, forward_batch, input_embeds
+            )
+
         if not get_embedding:
            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
         else:
            return self.pooler(hidden_states, forward_batch)
@@ -586,9 +606,29 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def get_embed(self):
+        return self.model.embed_tokens.weight
+
+    def set_embed(self, embed):
+        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
+        if (
+            hasattr(self.config, "target_hidden_size")
+            and self.config.target_hidden_size != self.config.hidden_size
+        ):
+            return
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self):
+        self.capture_aux_hidden_states = True
+        num_layers = self.config.num_hidden_layers
+        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
sglang/srt/models/llama_eagle.py
@@ -134,6 +134,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         )
 
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights: