sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,511 @@
|
|
1
|
+
import logging
|
2
|
+
import re
|
3
|
+
from functools import lru_cache
|
4
|
+
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from torch import nn
|
8
|
+
from transformers import (
|
9
|
+
Gemma3nAudioConfig,
|
10
|
+
Gemma3nConfig,
|
11
|
+
Gemma3nTextConfig,
|
12
|
+
Gemma3nVisionConfig,
|
13
|
+
PreTrainedModel,
|
14
|
+
)
|
15
|
+
from transformers.models.auto.modeling_auto import AutoModel
|
16
|
+
|
17
|
+
from sglang.srt.hf_transformers_utils import get_processor
|
18
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
19
|
+
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
20
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
21
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
22
|
+
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
23
|
+
from sglang.srt.managers.mm_utils import (
|
24
|
+
MultiModalityDataPaddingPatternTokenPairs,
|
25
|
+
general_mm_embed_routine,
|
26
|
+
)
|
27
|
+
from sglang.srt.managers.schedule_batch import (
|
28
|
+
MultimodalDataItem,
|
29
|
+
MultimodalInputs,
|
30
|
+
flatten_nested_list,
|
31
|
+
)
|
32
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
33
|
+
from sglang.srt.model_loader.weight_utils import (
|
34
|
+
default_weight_loader,
|
35
|
+
maybe_remap_kv_scale_name,
|
36
|
+
)
|
37
|
+
from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
|
38
|
+
from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
|
39
|
+
from sglang.srt.utils import add_prefix
|
40
|
+
|
41
|
+
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
42
|
+
|
43
|
+
# Memoized wrapper around get_processor (explicit form of lru_cache's default
# maxsize=128); all arguments passed to it must be hashable.
cached_get_processor = lru_cache(maxsize=128)(get_processor)
|
44
|
+
|
45
|
+
|
46
|
+
class Gemma3nImagePixelInputs(TypedDict):
    """Typed mapping describing raw image inputs for Gemma3n."""

    # Shape: `(batch_size * num_images, num_channels, height, width)`
    pixel_values: torch.Tensor
|
49
|
+
|
50
|
+
|
51
|
+
class Gemma3nAudioInputs(TypedDict):
    """Typed mapping describing raw audio inputs for Gemma3n."""

    # Shape: `(batch_size * num_audio, seq_length, num_features)`
    input_features: torch.Tensor
    # Shape: `(batch_size * num_audio, seq_length)`
    input_features_mask: torch.Tensor
|
56
|
+
|
57
|
+
|
58
|
+
class Gemma3nMultimodalEmbedder(nn.Module):
    """Embeds token ids or soft tokens for multimodal content into language model space.

    Hard tokens (ids) are looked up in a modality-specific embedding table and
    normalized with ``hard_embedding_norm``; soft tokens (encoder outputs) are
    normalized with ``soft_embedding_norm``. Both paths are then projected
    from the modality hidden size to the text hidden size and passed through
    a scale-free RMS norm.
    """

    def __init__(
        self,
        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
        text_config: Gemma3nTextConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.multimodal_hidden_size = multimodal_config.hidden_size
        self.eps = multimodal_config.rms_norm_eps
        # First id of this modality's slice of the global token-id space;
        # subtracted from incoming ids before the table lookup in forward().
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        self.text_hidden_size = text_config.hidden_size

        self.embedding = VocabParallelEmbedding(
            self.vocab_size,
            self.multimodal_hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("embedding", prefix),
        )

        self.hard_embedding_norm = Gemma3nRMSNorm(
            self.multimodal_hidden_size,
            eps=self.eps,
        )

        self.soft_embedding_norm = Gemma3nRMSNorm(
            self.multimodal_hidden_size,
            eps=self.eps,
        )

        self.embedding_projection = RowParallelLinear(
            self.multimodal_hidden_size,
            self.text_hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("embedding_projection", prefix),
        )

        self.embedding_post_projection_norm = Gemma3nRMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            with_scale=False,
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embeds token ids or soft tokens for multimodal content into language model space.

        Args:
            input_ids: A torch.LongTensor containing the token ids to embed. Values are
                expected in `[vocab_offset, vocab_offset + vocab_size)`; anything outside
                that range is mapped to the last id of this modality's vocabulary.
            inputs_embeds: A torch.Tensor containing the soft tokens to embed; its last
                dimension is expected to be the modality hidden size.

        Returns:
            A torch.Tensor of embeddings whose last dimension is the text hidden size
            (`text_config.hidden_size` passed at construction).

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
        """
        # Exactly one of the two inputs must be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is not None:
            emb_norm = self.soft_embedding_norm(inputs_embeds)
        else:
            # Shift ids into this modality's local vocabulary and redirect any id
            # outside [0, vocab_size) to the last local id, so the embedding lookup
            # never indexes out of range (which would trigger a CUDA device assert).
            adjusted_ids = input_ids - self.vocab_offset
            out_of_vocab = (adjusted_ids < 0) | (adjusted_ids >= self.vocab_size)
            adjusted_ids = adjusted_ids.masked_fill(out_of_vocab, self.vocab_size - 1)
            hard_emb = self.embedding(adjusted_ids)
            emb_norm = self.hard_embedding_norm(hard_emb)

        # The projection returns a pair; only the first element (the output) is used.
        emb_norm_proj, _ = self.embedding_projection(emb_norm)
        return self.embedding_post_projection_norm(emb_norm_proj)
|
142
|
+
|
143
|
+
|
144
|
+
class Gemma3nForConditionalGeneration(PreTrainedModel):
    """Gemma3n multimodal model for conditional generation.

    Combines a vision tower (HF ``AutoModel``), an audio encoder, per-modality
    embedders that project encoder outputs into the text hidden size, and the
    Gemma3n text model whose hidden states feed a ``LogitsProcessor``.
    """

    config_class = Gemma3nConfig

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
        ".out_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
        "out_proj": ("proj", 0),
    }

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3nConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config

        prefix = add_prefix("model", prefix)

        # Vision components
        # TODO: Use sglang's vision model
        self.vision_tower = AutoModel.from_config(config=config.vision_config)

        self.embed_vision = Gemma3nMultimodalEmbedder(
            config.vision_config,
            config.text_config,
            quant_config=quant_config,
            prefix=add_prefix("embed_vision", prefix),
        )

        # Audio components
        self.embed_audio = Gemma3nMultimodalEmbedder(
            config.audio_config,
            config.text_config,
            quant_config=quant_config,
            prefix=add_prefix("embed_audio", prefix),
        )

        self.audio_tower = Gemma3nAudioEncoder(
            config.audio_config,
            quant_config=quant_config,
            prefix=add_prefix("audio_tower", prefix),
        )

        # Text vocab size; get_audio_feature uses `vocab_size - 1` as its padding id.
        self.vocab_size = config.text_config.vocab_size
        self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input

        # Text model
        self.language_model = Gemma3nTextModel(
            config.text_config,
            quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        # Create logits processor for the multimodal model
        self.logits_processor = LogitsProcessor(config.text_config)

        self.post_init()

    def pad_input_ids(
        self,
        input_ids: List[int],
        mm_inputs: Optional[MultimodalInputs] = None,
    ) -> List[int]:
        """Pad input IDs with image and audio tokens.

        Collects (start, end) token-id pairs from ``mm_inputs.im_start_id`` /
        ``im_end_id`` and ``audio_start_id`` / ``audio_end_id`` (for each start
        attribute that exists) and delegates the padding to
        ``MultiModalityDataPaddingPatternTokenPairs``. Returns ``input_ids``
        unchanged when ``mm_inputs`` is None or no pair was collected.
        """
        if mm_inputs is None:
            return input_ids

        # Collect available media token pairs
        media_token_pairs = []
        for attr_name in ["im_start_id", "audio_start_id"]:
            if hasattr(mm_inputs, attr_name):
                start_id = getattr(mm_inputs, attr_name)
                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
                media_token_pairs.append((start_id, end_id))

        # Apply padding pattern if we have media tokens
        if media_token_pairs:
            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
            return pattern.pad_input_tokens(input_ids, mm_inputs)

        return input_ids

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the text model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def get_attention_sliding_window_size(self):
        """Return the text config's sliding window minus one."""
        return self.config.text_config.sliding_window - 1

    def get_image_feature(self, items: List[MultimodalDataItem]):
        """Encode images and project them into language model space.

        Args:
            items: Multimodal data items whose ``pixel_values`` hold the images.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape
            `(num_images, vision_soft_tokens_per_image, text_hidden_size)`.

        Raises:
            ValueError: if a ``pixel_values`` tensor cannot be normalized to 4-D.
        """
        # Process images one by one to handle flatten_batch=True constraint in vision_tower
        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
        vision_outputs_list = []

        for pixel_values_batch in all_pixel_values:
            # Normalize input shape to [batch_size, channels, height, width]
            if pixel_values_batch.dim() == 5:
                pixel_values_batch = pixel_values_batch.squeeze(0)
            elif pixel_values_batch.dim() == 3:
                pixel_values_batch = pixel_values_batch.unsqueeze(0)
            # Re-check after normalization: a 5-D tensor whose leading dim is not 1
            # is not squeezed and must be rejected here rather than deep inside
            # the vision tower.
            if pixel_values_batch.dim() != 4:
                raise ValueError(
                    f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
                )

            # Process each image in the batch
            batch_size = pixel_values_batch.shape[0]
            for i in range(batch_size):
                pixel_value = pixel_values_batch[i : i + 1]  # Keep batch dimension as 1
                pixel_value = pixel_value.to(
                    device=self.vision_tower.device, dtype=self.language_model.dtype()
                )
                vision_outputs = self.vision_tower(
                    pixel_values=pixel_value, do_pooling=False, return_dict=True
                ).last_hidden_state
                vision_outputs_list.append(vision_outputs)

        # Concatenate all vision outputs
        vision_outputs = torch.cat(vision_outputs_list, dim=0)

        # Flatten the feature map into a token sequence:
        # (N, hidden, soft_tokens) -> (N, soft_tokens, hidden). Assumes
        # last_hidden_state is channel-first, e.g. (N, C, H, W) -- TODO confirm.
        vision_outputs = vision_outputs.reshape(
            vision_outputs.shape[0],
            self.config.vision_config.hidden_size,
            self.config.vision_soft_tokens_per_image,
        ).permute(0, 2, 1)

        # Scale by sqrt(hidden_size), then embed the soft tokens into language model space
        vision_outputs *= self.config.vision_config.hidden_size**0.5
        return self.embed_vision(inputs_embeds=vision_outputs)

    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
        """
        Projects the last hidden state from the audio encoder into language model space.

        Each clip is encoded separately. Its own padding mask is applied, and it is
        right-padded to ``config.audio_soft_tokens_per_image`` tokens before the clips
        are concatenated.

        Args:
            items: List of multimodal data items containing audio data.

        Returns:
            audio_features (`torch.Tensor`): Audio feature tensor of shape
            `(num_audios, audio_soft_tokens_per_image, embed_dim)`, or an empty
            `(0, 0, hidden_size)` tensor when there is no audio.
        """
        # Extract audio features and masks from items
        all_input_features = flatten_nested_list(
            [item.input_features for item in items]
        )
        all_input_features_mask = flatten_nested_list(
            [~item.input_features_mask for item in items]
        )  # Note(Xinyuan): reverse the mask according to the HF implementation

        target_len = self.config.audio_soft_tokens_per_image
        audio_padding_embs = None
        audio_features_list = []

        for input_features, input_features_mask in zip(
            all_input_features, all_input_features_mask
        ):
            # Ensure proper tensor format
            if input_features.dim() == 2:
                input_features = input_features.unsqueeze(0)
            if input_features_mask.dim() == 1:
                input_features_mask = input_features_mask.unsqueeze(0)

            # Move to device and dtype
            input_features = input_features.to(
                device=next(self.audio_tower.parameters()).device,
                dtype=self.language_model.dtype(),
            )
            input_features_mask = input_features_mask.to(device=input_features.device)

            # Process through audio tower and embed into language model space
            audio_outputs, audio_mask = self.audio_tower(
                input_features, input_features_mask
            )
            audio_embeds = self.embed_audio(inputs_embeds=audio_outputs)

            # The Gemma3nProcessor expects all audio to be 30s long and inserts a fixed
            # number of audio soft tokens (audio_soft_tokens_per_image) into the text.
            # The encoder may produce fewer tokens, so masked positions are replaced with,
            # and the clip is right-padded by, the embedding of id `self.vocab_size - 1`.
            # The embedder maps that id (outside its own vocab range) to the last
            # embed_audio token. Computed once and reused for every clip.
            if audio_padding_embs is None:
                audio_padding_toks = torch.tensor(
                    [[self.vocab_size - 1]], dtype=torch.long, device=audio_embeds.device
                )
                audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)

            # Apply THIS clip's mask. Using a single mask for the whole concatenated
            # batch would leak one clip's padding pattern onto the others.
            audio_embeds = torch.where(
                audio_mask.unsqueeze(-1), audio_padding_embs, audio_embeds
            )

            # Right-pad to the fixed soft-token count. Assumes the encoder never
            # produces more than target_len tokens (expand rejects a negative size).
            clip_batch, clip_seq_len, clip_embed_dim = audio_embeds.shape
            extra_padding_features = audio_padding_embs.expand(
                clip_batch, target_len - clip_seq_len, clip_embed_dim
            )
            audio_features_list.append(
                torch.cat((audio_embeds, extra_padding_features), dim=1)
            )

        if not audio_features_list:
            return torch.empty(
                0,
                0,
                self.language_model.config.hidden_size,
                device=next(self.parameters()).device,
                dtype=self.language_model.dtype(),
            )

        # All clips now share the same sequence length, so they concatenate cleanly.
        return torch.cat(audio_features_list, dim=0)

    def get_per_layer_inputs(
        self, input_ids: torch.LongTensor
    ) -> Optional[torch.Tensor]:
        """Delegate to the text model's ``get_per_layer_inputs``."""
        return self.language_model.get_per_layer_inputs(input_ids)

    def project_per_layer_inputs(
        self,
        inputs_embeds: torch.Tensor,
        per_layer_inputs: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Delegate to the text model's ``project_per_layer_inputs``."""
        return self.language_model.project_per_layer_inputs(
            inputs_embeds, per_layer_inputs
        )

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        """Forward pass for multimodal Gemma3n.

        Embeds text, image and audio inputs via ``general_mm_embed_routine``, runs
        the text model, and returns the output of ``self.logits_processor``
        (the annotation names the processor class, not its output type).

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``input_embeds``
                are given.
        """
        if (input_ids is None) ^ (input_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        # NOTE(review): increments the caller's `positions` tensor IN PLACE; anything
        # else holding this tensor sees the shifted values. Kept as-is because other
        # code may rely on it -- confirm before changing to an out-of-place add.
        positions += 1

        # Defined on every path. NOTE(review): `input_embeds` is only validated above
        # and is not forwarded to general_mm_embed_routine, so the input_embeds-only
        # path is presumably unsupported downstream.
        per_layer_inputs = None
        if input_ids is not None:
            # Prepare per-layer inputs from input_ids; ids outside
            # [0, vocab_size_per_layer_input) (e.g. multimodal ids) are replaced by 0.
            per_layer_inputs_mask = torch.logical_and(
                input_ids >= 0, input_ids < self.vocab_size_per_layer_input
            )
            per_layer_inputs_tokens = torch.where(
                per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
            )
            per_layer_inputs = self.language_model.get_per_layer_inputs(
                per_layer_inputs_tokens
            )

        # Use general_mm_embed_routine for handling multimodal data
        # This will automatically handle text, image, and audio embeddings
        hidden_states = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=self.get_image_feature,
            audio_data_embedding_func=self.get_audio_feature,
            positions=positions,
            per_layer_inputs=per_layer_inputs,
        )

        # Process hidden states through logits processor
        return self.logits_processor(
            input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
        )

    def tie_weights(self):
        """Delegate weight tying to the text model."""
        return self.language_model.tie_weights()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load weights for the model.

        Strips a leading ``model.`` from checkpoint names, routes q/k/v and
        gate/up shards into the fused ``qkv_proj`` / ``gate_up_proj`` params,
        and loads everything else with the param's ``weight_loader``
        (or ``default_weight_loader``).

        Returns:
            The set of parameter names that were loaded (fused and unfused).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".up_proj", 1),
            (".gate_up_proj", ".gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            name = re.sub(r"^model\.", "", name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. This `continue` only
                # advances the inner loop; the renamed bias then typically falls
                # through to the `else` branch, whose identical check skips it.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "vision_model" in name:
                    # adapt to VisionAttention
                    name = name.replace(".self_attn.out_proj", ".self_attn.proj")
                # Skip loading extra bias for GPTQ models
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            # Record both fused (stacked) and directly-loaded params.
            loaded_params.add(name)
        return loaded_params
|
509
|
+
|
510
|
+
|
511
|
+
# Module-level alias exposing the model class under the name `EntryClass`;
# presumably looked up by sglang's model registry via this name -- confirm.
EntryClass = Gemma3nForConditionalGeneration
|