sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,462 @@
|
|
1
|
+
# Copyright 2025 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
|
15
|
+
# Adapted from:
|
16
|
+
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
|
17
|
+
|
18
|
+
import logging
|
19
|
+
from functools import lru_cache
|
20
|
+
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
21
|
+
|
22
|
+
import torch
|
23
|
+
from torch import nn
|
24
|
+
from transformers import (
|
25
|
+
AutoModel,
|
26
|
+
BatchFeature,
|
27
|
+
Gemma3Config,
|
28
|
+
Gemma3Processor,
|
29
|
+
PreTrainedModel,
|
30
|
+
)
|
31
|
+
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
32
|
+
|
33
|
+
from sglang.srt.hf_transformers_utils import get_processor
|
34
|
+
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
35
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
36
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
37
|
+
from sglang.srt.managers.mm_utils import (
|
38
|
+
MultiModalityDataPaddingPatternTokenPairs,
|
39
|
+
general_mm_embed_routine,
|
40
|
+
)
|
41
|
+
from sglang.srt.managers.schedule_batch import MultimodalInputs
|
42
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
43
|
+
from sglang.srt.model_loader.weight_utils import (
|
44
|
+
default_weight_loader,
|
45
|
+
maybe_remap_kv_scale_name,
|
46
|
+
)
|
47
|
+
from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
|
48
|
+
from sglang.srt.utils import add_prefix
|
49
|
+
|
50
|
+
# Module-level logger for this model file.
logger = logging.getLogger(__name__)

# Memoized wrapper around ``get_processor``; identical to ``lru_cache(get_processor)``,
# which also uses functools' default bound of 128 entries.
cached_get_processor = lru_cache(maxsize=128)(get_processor)
|
53
|
+
|
54
|
+
|
55
|
+
class Gemma3ImagePixelInputs(TypedDict):
    """Typed dictionary describing Gemma3 image inputs.

    Keys:
        pixel_values: image tensor of shape
            ``(batch_size * num_images, num_channels, height, width)``.
    """

    # Shape: (batch_size * num_images, num_channels, height, width)
    pixel_values: torch.Tensor
|
58
|
+
|
59
|
+
|
60
|
+
class Gemma3MultiModalProjector(nn.Module):
    """Map Gemma3 vision-tower outputs into the text embedding space.

    The patch grid is average-pooled down to ``tokens_per_side ** 2`` tokens,
    RMS-normalized, and multiplied by a learned
    ``(vision_hidden, text_hidden)`` projection matrix.
    """

    def __init__(self, config: Gemma3Config):
        super().__init__()
        vision_cfg = config.vision_config
        text_cfg = config.text_config

        # Registration order (projection weight first, then the norm) is kept
        # so parameter naming/ordering matches the original module.
        self.mm_input_projection_weight = nn.Parameter(
            torch.zeros(vision_cfg.hidden_size, text_cfg.hidden_size)
        )
        self.mm_soft_emb_norm = Gemma3RMSNorm(
            vision_cfg.hidden_size, eps=vision_cfg.layer_norm_eps
        )

        # Side length of the square patch grid produced by the vision tower.
        self.patches_per_image = int(vision_cfg.image_size // vision_cfg.patch_size)
        # Side length of the pooled token grid (mm_tokens_per_image is square).
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor) -> torch.Tensor:
        """Pool, normalize and project ``(batch, seq, hidden)`` vision features."""
        batch, _, channels = vision_outputs.shape
        side = self.patches_per_image

        # (B, S, C) -> (B, C, side, side): lay the patch sequence out as a grid.
        grid = (
            vision_outputs.transpose(1, 2)
            .reshape(batch, channels, side, side)
            .contiguous()
        )

        # Pool the grid, then flatten back to a token sequence: (B, t*t, C).
        tokens = self.avg_pool(grid).flatten(2).transpose(1, 2)

        # Normalize and project into the text hidden size.
        projected = self.mm_soft_emb_norm(tokens) @ self.mm_input_projection_weight
        return projected.type_as(vision_outputs)
|
109
|
+
|
110
|
+
|
111
|
+
class Gemma3ForConditionalGeneration(PreTrainedModel):
    """Gemma3 multimodal model for conditional generation.

    Wires together a vision tower built from ``config.vision_config``, a
    ``Gemma3MultiModalProjector`` that maps vision features into the text
    embedding space, and a ``Gemma3ForCausalLM`` language model.
    """

    config_class = Gemma3Config

    # BitsAndBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name: (weight_name, index)
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    # Fused projection name -> checkpoint projections packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    # Gemma does not apply LoRA to the embedding layer.
    embedding_modules = {}
    embedding_padding_modules = []
    supports_lora = True

    def __init__(
        self,
        config: Gemma3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config=config)
        self.config = config
        self.quant_config = quant_config
        # Vision components
        # TODO: replace with vision attention
        # self.vision_tower = SiglipVisionModel(
        #     config.vision_config,
        #     quant_config,
        #     prefix=add_prefix("vision_tower", prefix),
        # )
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = Gemma3MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        # Text model
        self.language_model = Gemma3ForCausalLM(
            config.text_config, quant_config, prefix=add_prefix("model", prefix)
        )
        # Fold an optional top-level ``config.logit_scale`` into the language
        # model's logits processor (only when it already has a truthy scale).
        if self.language_model.logits_processor.logit_scale:
            logit_scale = getattr(config, "logit_scale", 1.0)
            self.language_model.logits_processor.logit_scale *= logit_scale
        self.post_init()

    def pad_input_ids(
        self, input_ids: List[int], image_inputs: MultimodalInputs
    ) -> List[int]:
        """Pad input IDs with image tokens.

        Delegates to ``MultiModalityDataPaddingPatternTokenPairs`` using the
        ``(im_start_id, im_end_id)`` token pair carried by ``image_inputs``.
        """
        # Get special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        ids = pattern.pad_input_tokens(input_ids, image_inputs)
        return ids

    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ) -> Dict:
        """Prepare attention masks for multimodal inputs.

        The flattened batch is split into sequences wherever ``positions == 0``.
        Tokens before the first such entry are not assigned to any sequence.
        For each sequence of length ``L``, two additive masks of shape
        ``(1, 1, L, L)`` and dtype ``mask_dtype`` are built. In both,
        0 means attend and ``-inf`` means blocked.

        * global: causal, except that any two image tokens
          (``config.image_token_index``) in the sequence may attend to each
          other in both directions;
        * local: the global mask, additionally blocking keys that are
          ``text_config.interleaved_sliding_window`` or more positions behind
          the query.

        Returns:
            ``kwargs`` updated with ``has_images``, ``seq_lens``,
            ``global_attn_masks`` and ``local_attn_masks``.
        """
        kwargs["has_images"] = True

        # Distinguish sequences by position id 0
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []

        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)

        kwargs["seq_lens"] = seq_lens

        # Create attention masks
        global_attn_masks = []
        local_attn_masks = []
        sliding_window = self.config.text_config.interleaved_sliding_window

        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            # Create global causal mask: -inf strictly above the diagonal.
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Bidirectional attention between image tokens: cells whose row AND
            # column are both image tokens reach a count of 2 and are unmasked.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # Local mask: block keys >= sliding_window positions behind the query.
            local_attn_mask = torch.ones_like(global_attn_mask)
            local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
            local_attn_mask = torch.where(
                local_attn_mask == 0, global_attn_mask, float("-inf")
            )
            local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the language model's token embedding module."""
        return self.language_model.get_input_embeddings()

    def get_image_feature(self, image_input: MultimodalInputs):
        """Project the last hidden state of the vision model into language model space.

        Args:
            image_input: multimodal inputs whose ``pixel_values`` hold the
                images. This presumably has shape
                ``(num_images, channels, height, width)`` (TODO confirm with
                the processor).

        Returns:
            image_features (`torch.Tensor`): output of
            ``Gemma3MultiModalProjector``, one row of tokens per image.
        """
        pixel_values = image_input.pixel_values
        # Follow the vision tower's device instead of hard-coding "cuda", and
        # cast to the language model's dtype in the same call.
        vision_device = next(self.vision_tower.parameters()).device
        pixel_values = pixel_values.to(
            device=vision_device, dtype=self.language_model.dtype()
        )

        vision_outputs = self.vision_tower(pixel_values=pixel_values).last_hidden_state
        image_features = self.multi_modal_projector(vision_outputs)
        return image_features

    def embed_mm_inputs(
        self,
        input_ids: torch.Tensor,
        forward_batch: ForwardBatch,
        image_input: MultimodalInputs,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and scatter image features into the image-token slots.

        Image-token slots are positions whose id is in ``image_input.pad_values``.
        ``forward_batch`` is accepted for interface compatibility and unused.
        Note that ``input_ids`` is clamped in place to ``[0, vocab_size - 1]``.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("Unimplemented")
        # boolean-masking image tokens
        special_image_mask = torch.isin(
            input_ids,
            torch.tensor(image_input.pad_values, device=input_ids.device),
        ).unsqueeze(-1)
        num_image_tokens_in_input_ids = special_image_mask.sum()

        if num_image_tokens_in_input_ids == 0:
            return self.get_input_embeddings()(input_ids)

        # get_image_feature reads ``.pixel_values`` itself, so it must receive
        # the MultimodalInputs object, not the tensor.
        image_features = self.get_image_feature(image_input)

        num_image_tokens_in_embedding = (
            image_features.shape[0] * image_features.shape[1]
        )

        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
            # Keep only as many whole images as the text has slots for.
            num_image = num_image_tokens_in_input_ids // image_features.shape[1]
            image_features = image_features[:num_image, :]
            logger.warning(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                "tokens from image embeddings."
            )

        # Important: clamp after extracting original image boundaries, so the
        # (possibly out-of-vocab) pad ids can be embedded without index errors.
        input_ids.clamp_(min=0, max=self.vocab_size - 1)

        inputs_embeds = self.get_input_embeddings()(input_ids)

        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )

        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        return inputs_embeds

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        **kwargs: object,
    ) -> LogitsProcessor:
        """Run the multimodal forward pass.

        Token embeddings (with image features merged in by
        ``general_mm_embed_routine`` via ``get_image_feature``) are fed to the
        language model as ``input_embeds``.

        Args:
            input_ids: token ids for the batch, image tokens included.
            positions: position ids. They are incremented by 1 IN PLACE,
                because Gemma3 positions are 1-indexed.
            forward_batch: batch metadata passed to the embedding routine and
                the language model.
            input_embeds: accepted for interface compatibility but ignored;
                embeddings are always computed from ``input_ids``.
            **kwargs: forwarded to the language model.

        Returns:
            The output of ``self.language_model`` for this batch.
        """
        # Important: position_ids in Gemma3 are 1-indexed.
        # NOTE(review): ``+=`` mutates the caller's tensor in place; confirm
        # no caller reads ``positions`` again expecting 0-based values.
        positions += 1

        # Replace the image token id with 0 (presumably the pad id) if it is
        # out of vocabulary, to avoid index errors in the embedding lookup.
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        inputs_embeds = general_mm_embed_routine(
            input_ids=llm_input_ids,
            forward_batch=forward_batch,
            embed_tokens=self.get_input_embeddings(),
            mm_data_embedding_func=self.get_image_feature,
        )

        outputs = self.language_model(
            input_ids=None,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=inputs_embeds,
            **kwargs,
        )

        return outputs

    def tie_weights(self):
        """Delegate weight tying to the language model."""
        return self.language_model.tie_weights()

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into the model.

        Weights whose name contains ``"language_model"`` are delegated to
        ``Gemma3ForCausalLM.load_weights`` (invoked with this model as
        ``self``). All other weights, such as the vision tower and the
        projector, are loaded here.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "language_model" in name:
                causal_loaded_params = Gemma3ForCausalLM.load_weights(
                    self, [(name, loaded_weight)]
                )
                loaded_params.update(causal_loaded_params)
                continue

            # Skip lm_head.weight as it's tied with embed_tokens
            if "lm_head.weight" in name:
                continue

            # Skip loading extra bias for GPTQ models
            if name.endswith(".bias") and name not in params_dict:
                continue

            # Remapping the name of FP8 kv-scale
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            # Best-effort loading: raising here was deliberately disabled, but
            # missing weights should still be visible rather than silently ignored.
            logger.warning(
                "Some weights are not initialized from checkpoints: %s",
                sorted(unloaded_params),
            )
        return loaded_params
|
458
|
+
|
459
|
+
|
460
|
+
# Let transformers' AutoModel resolve Gemma3Config to this implementation;
# exist_ok=True tolerates an already-registered mapping.
AutoModel.register(Gemma3Config, Gemma3ForConditionalGeneration, exist_ok=True)

# Model class exported by this module (presumably picked up by sglang's model
# registry via the ``EntryClass`` name — confirm against the registry).
EntryClass = Gemma3ForConditionalGeneration
|