sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (54)
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma3n_mm.py (new file, +511 lines)
@@ -0,0 +1,511 @@
+import logging
+import re
+from functools import lru_cache
+from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
+
+import torch
+from torch import nn
+from transformers import (
+    Gemma3nAudioConfig,
+    Gemma3nConfig,
+    Gemma3nTextConfig,
+    Gemma3nVisionConfig,
+    PreTrainedModel,
+)
+from transformers.models.auto.modeling_auto import AutoModel
+
+from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
+from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Gemma3nImagePixelInputs(TypedDict):
+    pixel_values: torch.Tensor
+    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+
+
+class Gemma3nAudioInputs(TypedDict):
+    input_features: torch.Tensor
+    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
+    input_features_mask: torch.Tensor
+    """Shape: `(batch_size * num_audio, seq_length)`"""
+
+
+class Gemma3nMultimodalEmbedder(nn.Module):
+    """Embeds token ids or soft tokens for multimodal content into language model space."""
+
+    def __init__(
+        self,
+        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+        text_config: Gemma3nTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.multimodal_hidden_size = multimodal_config.hidden_size
+        self.eps = multimodal_config.rms_norm_eps
+        self.vocab_offset = multimodal_config.vocab_offset
+        self.vocab_size = multimodal_config.vocab_size
+        self.text_hidden_size = text_config.hidden_size
+
+        self.embedding = VocabParallelEmbedding(
+            self.vocab_size,
+            self.multimodal_hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("embedding", prefix),
+        )
+
+        self.hard_embedding_norm = Gemma3nRMSNorm(
+            self.multimodal_hidden_size,
+            eps=self.eps,
+        )
+
+        self.soft_embedding_norm = Gemma3nRMSNorm(
+            self.multimodal_hidden_size,
+            eps=self.eps,
+        )
+
+        self.embedding_projection = RowParallelLinear(
+            self.multimodal_hidden_size,
+            self.text_hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("embedding_projection", prefix),
+        )
+
+        self.embedding_post_projection_norm = Gemma3nRMSNorm(
+            self.text_hidden_size,
+            eps=self.eps,
+            with_scale=False,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Embeds token ids or soft tokens for multimodal content into language model space.
+
+        Args:
+            input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
+                `[vocab_offset, vocab_offset + vocab_size)`.
+            inputs_embeds: A torch.Tensor containing the soft tokens to embed.
+
+        Returns:
+            A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
+        """
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        if inputs_embeds is not None:
+            emb_norm = self.soft_embedding_norm(inputs_embeds)
+        else:
+            # Handle out of vocab ids to prevent CUDA assertion failures
+            out_of_vocab_id = self.vocab_size - 1
+            adjusted_ids = input_ids - self.vocab_offset
+            adjusted_ids = torch.where(adjusted_ids < 0, out_of_vocab_id, adjusted_ids)
+            adjusted_ids = torch.where(
+                adjusted_ids >= self.vocab_size, out_of_vocab_id, adjusted_ids
+            )
+            hard_emb = self.embedding(adjusted_ids)
+            emb_norm = self.hard_embedding_norm(hard_emb)
+
+        emb_norm_proj, _ = self.embedding_projection(emb_norm)
+        return self.embedding_post_projection_norm(emb_norm_proj)
+
+
+class Gemma3nForConditionalGeneration(PreTrainedModel):
+    config_class = Gemma3nConfig
+    """Gemma3n multimodal model for conditional generation."""
+
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+        ".out_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+        "out_proj": ("proj", 0),
+    }
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer
+    embedding_modules = {}
+    embedding_padding_modules = []
+    supports_lora = True
+
+    def __init__(
+        self,
+        config: Gemma3nConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config=config)
+        self.config = config
+        self.quant_config = quant_config
+
+        prefix = add_prefix("model", prefix)
+
+        # Vision components
+        # TODO: Use sglang's vision model
+        self.vision_tower = AutoModel.from_config(config=config.vision_config)
+
+        self.embed_vision = Gemma3nMultimodalEmbedder(
+            config.vision_config,
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_vision", prefix),
+        )
+
+        # Audio components
+        self.embed_audio = Gemma3nMultimodalEmbedder(
+            config.audio_config,
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("embed_audio", prefix),
+        )
+
+        self.audio_tower = Gemma3nAudioEncoder(
+            config.audio_config,
+            quant_config=quant_config,
+            prefix=add_prefix("audio_tower", prefix),
+        )
+
+        self.vocab_size = config.text_config.vocab_size
+        self.vocab_size_per_layer_input = config.text_config.vocab_size_per_layer_input
+
+        # Text model
+        self.language_model = Gemma3nTextModel(
+            config.text_config,
+            quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+
+        # Create logits processor for the multimodal model
+        self.logits_processor = LogitsProcessor(config.text_config)
+
+        self.post_init()
+
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        mm_inputs: Optional[MultimodalInputs] = None,
+    ) -> List[int]:
+        """Pad input IDs with image and audio tokens."""
+        if mm_inputs is None:
+            return input_ids
+
+        # Collect available media token pairs
+        media_token_pairs = []
+        for attr_name in ["im_start_id", "audio_start_id"]:
+            if hasattr(mm_inputs, attr_name):
+                start_id = getattr(mm_inputs, attr_name)
+                end_id = getattr(mm_inputs, attr_name.replace("start", "end"))
+                media_token_pairs.append((start_id, end_id))
+
+        # Apply padding pattern if we have media tokens
+        if media_token_pairs:
+            pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
+            return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+        return input_ids
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.get_input_embeddings()
+
+    def get_attention_sliding_window_size(self):
+        return self.config.text_config.sliding_window - 1
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        """
+        # Process images one by one to handle flatten_batch=True constraint in vision_tower
+        all_pixel_values = flatten_nested_list([item.pixel_values for item in items])
+        vision_outputs_list = []
+
+        for pixel_values_batch in all_pixel_values:
+            # Normalize input shape to [batch_size, channels, height, width]
+            if pixel_values_batch.dim() == 5:
+                pixel_values_batch = pixel_values_batch.squeeze(0)
+            elif pixel_values_batch.dim() == 3:
+                pixel_values_batch = pixel_values_batch.unsqueeze(0)
+            elif pixel_values_batch.dim() != 4:
+                raise ValueError(
+                    f"Unexpected pixel_values shape: {pixel_values_batch.shape}"
+                )
+
+            # Process each image in the batch
+            batch_size = pixel_values_batch.shape[0]
+            for i in range(batch_size):
+                pixel_value = pixel_values_batch[i : i + 1]  # Keep batch dimension as 1
+                pixel_value = pixel_value.to(
+                    device=self.vision_tower.device, dtype=self.language_model.dtype()
+                )
+                vision_outputs = self.vision_tower(
+                    pixel_values=pixel_value, do_pooling=False, return_dict=True
+                ).last_hidden_state
+                vision_outputs_list.append(vision_outputs)
+
+        # Concatenate all vision outputs
+        vision_outputs = torch.cat(vision_outputs_list, dim=0)
+
+        # Convert from (batch, channels, height, width) to (batch, height * width, channels)
+        vision_outputs = vision_outputs.reshape(
+            vision_outputs.shape[0],
+            self.config.vision_config.hidden_size,
+            self.config.vision_soft_tokens_per_image,
+        ).permute(0, 2, 1)
+
+        # Normalize and embed the soft tokens into language model space
+        vision_outputs *= self.config.vision_config.hidden_size**0.5
+        return self.embed_vision(inputs_embeds=vision_outputs)
+
+    def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        """
+        Projects the last hidden state from the audio encoder into language model space.
+
+        Args:
+            items: List of multimodal data items containing audio data.
+
+        Returns:
+            audio_features (`torch.Tensor`): Audio feature tensor of shape `(num_audios, audio_length, embed_dim)`.
+        """
+        # Extract audio features and masks from items
+        all_input_features = flatten_nested_list(
+            [item.input_features for item in items]
+        )
+        all_input_features_mask = flatten_nested_list(
+            [~item.input_features_mask for item in items]
+        )  # Note(Xinyuan): reverse the mask according to the HF implementation
+
+        # Process audio features one by one
+        audio_features_list = []
+
+        for input_features, input_features_mask in zip(
+            all_input_features, all_input_features_mask
+        ):
+            # Ensure proper tensor format
+            if input_features.dim() == 2:
+                input_features = input_features.unsqueeze(0)
+            if input_features_mask.dim() == 1:
+                input_features_mask = input_features_mask.unsqueeze(0)
+
+            # Move to device and dtype
+            input_features = input_features.to(
+                device=next(self.audio_tower.parameters()).device,
+                dtype=self.language_model.dtype(),
+            )
+            input_features_mask = input_features_mask.to(device=input_features.device)
+
+            # Process through audio tower
+            audio_outputs, audio_mask = self.audio_tower(
+                input_features, input_features_mask
+            )
+
+            # Embed the audio outputs
+            audio_embeds = self.embed_audio(inputs_embeds=audio_outputs)
+            audio_features_list.append(audio_embeds)
+
+        # Concatenate all audio features
+        if audio_features_list:
+            audio_features = torch.cat(audio_features_list, dim=0)
+
+            # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
+            # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
+            # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
+            # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
+            # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+            audio_padding_toks = torch.tensor(
+                [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
+            )
+            audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
+            audio_features = torch.where(
+                audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
+            )
+
+            audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+            extra_padding_tokens = (
+                self.config.audio_soft_tokens_per_image - audio_seq_len
+            )
+            extra_padding_features = audio_padding_embs.expand(
+                audio_batch_size, extra_padding_tokens, audio_embed_dim
+            )
+
+            audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+            return audio_features
+        else:
+            return torch.empty(
+                0,
+                0,
+                self.language_model.config.hidden_size,
+                device=next(self.parameters()).device,
+                dtype=self.language_model.dtype(),
+            )
+
+    def get_per_layer_inputs(
+        self, input_ids: torch.LongTensor
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.get_per_layer_inputs(input_ids)
+
+    def project_per_layer_inputs(
+        self,
+        inputs_embeds: torch.Tensor,
+        per_layer_inputs: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.language_model.project_per_layer_inputs(
+            inputs_embeds, per_layer_inputs
+        )
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        **kwargs: object,
+    ) -> LogitsProcessor:
+        """Forward pass for multimodal Gemma3n."""
+        if (input_ids is None) ^ (input_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or input_embeds"
+            )
+
+        positions += 1
+
+        if input_ids is not None:
+            # Prepare per-layer inputs from input_ids
+            per_layer_inputs_mask = torch.logical_and(
+                input_ids >= 0, input_ids < self.vocab_size_per_layer_input
+            )
+            per_layer_inputs_tokens = torch.where(
+                per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
+            )
+            per_layer_inputs = self.language_model.get_per_layer_inputs(
+                per_layer_inputs_tokens
+            )
+
+        # Use general_mm_embed_routine for handling multimodal data
+        # This will automatically handle text, image, and audio embeddings
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            audio_data_embedding_func=self.get_audio_feature,
+            positions=positions,
+            per_layer_inputs=per_layer_inputs,
+        )
+
+        # Process hidden states through logits processor
+        return self.logits_processor(
+            input_ids, hidden_states, self.language_model.embed_tokens, forward_batch
+        )
+
+    def tie_weights(self):
+        return self.language_model.tie_weights()
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".up_proj", 1),
+            (".gate_up_proj", ".gate_proj", 0),
+        ]
+        """Load weights for the model."""
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+
+        for name, loaded_weight in weights:
+            name = re.sub(r"^model\.", "", name)
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "vision_model" in name:
+                    # adapt to VisionAttention
+                    name = name.replace(".self_attn.out_proj", ".self_attn.proj")
+                # Skip loading extra bias for GPTQ models
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
+EntryClass = Gemma3nForConditionalGeneration