ai-edge-torch-nightly 0.4.0.dev20250311__py3-none-any.whl → 0.4.0.dev20250313__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/gemma2.py +3 -0
- ai_edge_torch/generative/examples/gemma3/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +124 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py +96 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py +463 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py +212 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py +149 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +436 -0
- ai_edge_torch/generative/examples/gemma3/gemma3.py +176 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -3
- ai_edge_torch/odml_torch/composite/mark_tensor.py +0 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250313.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250313.dist-info}/RECORD +18 -9
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250313.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250313.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250313.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,149 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building an image encoder of Gemma3 model which is Siglip."""
|
17
|
+
|
18
|
+
from ai_edge_torch.generative.examples.paligemma import image_encoder
|
19
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
20
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
21
|
+
import torch
|
22
|
+
from torch import nn
|
23
|
+
import torch.nn.functional as F
|
24
|
+
|
25
|
+
|
26
|
+
# Checkpoint tensor names for the SigLIP vision tower. `{}` is a per-layer
# placeholder (presumably filled with the encoder layer index by the loader).
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
    ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
    attn_query_proj=(
        "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
    ),
    attn_key_proj=(
        "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
    ),
    attn_value_proj=(
        "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
    ),
    attn_output_proj=(
        "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
    ),
    pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
    # Every encoder layer is configured with a pre-feed-forward LayerNorm
    # (`pre_ff_norm_config` in `get_image_encoder_config`). Without this entry
    # its weights are never loaded; loading uses strict=False, so the norm
    # would silently keep its default initialization.
    pre_ff_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
    embedding="vision_tower.vision_model.embeddings.patch_embedding",
    embedding_position=(
        "vision_tower.vision_model.embeddings.position_embedding.weight"
    ),
    final_norm="vision_tower.vision_model.post_layernorm",
)
|
48
|
+
|
49
|
+
|
50
|
+
class SiglipExit(nn.Module):
  """Pools SigLIP patch embeddings down to the configured token count.

  The input sequence is viewed as a square grid of patch embeddings. If its
  side length differs from sqrt(`config.num_mm_tokens_per_image`), the grid is
  average-pooled with a square window and stride chosen so the side shrinks by
  that integer factor.
  """

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()
    # Target grid side length; a float, e.g. 16.0 for 256 tokens.
    self.expected_length = config.num_mm_tokens_per_image**0.5

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Returns `x` unchanged or average-pooled to the expected grid size.

    Args:
      x: 3-D tensor indexed as (batch, tokens, channels); the token count is
        assumed to be a perfect square.

    Returns:
      Tensor of shape (batch, pooled_tokens, channels).
    """
    grid_side = int(x.shape[1] ** 0.5)
    if grid_side == self.expected_length:
      return x

    pool = int(grid_side // self.expected_length)
    channels_first = x.transpose(1, 2)
    batch, channels = channels_first.shape[0], channels_first.shape[1]
    grid = channels_first.view(batch, channels, grid_side, grid_side)
    pooled = F.avg_pool2d(grid, pool, stride=pool)
    return pooled.view(batch, channels, -1).transpose(1, 2)
|
68
|
+
|
69
|
+
class SiglipVisionEncoderWithExit(nn.Module):
  """Gemma3 vision tower: a SigLIP encoder followed by `SiglipExit` pooling."""

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()
    # Submodules are registered in this order (encoder, then exit).
    self.siglip_encoder = image_encoder.SiglipVisionEncoder(config)
    self.siglip_exit = SiglipExit(config)

  def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """Encodes `pixel_values` and pools the resulting patch embeddings."""
    return self.siglip_exit(self.siglip_encoder(pixel_values))
|
81
|
+
|
82
|
+
def get_image_encoder_config() -> cfg.ModelConfig:
  """Builds the SigLIP image-encoder config used by the Gemma3 4B model.

  The encoder has 27 layers and a 1152-dim embedding. It takes 3-channel
  896x896 images split into 14x14 patches, and `num_mm_tokens_per_image` is
  set to 256.

  Returns:
    The model config for the image encoder of a Gemma3 4B model.
  """
  # One LayerNorm config object is shared by every norm site.
  layer_norm = cfg.NormalizationConfig(
      type=cfg.NormalizationType.LAYER_NORM,
      epsilon=1e-6,
      enable_hlfb=True,
  )
  encoder_block = cfg.TransformerBlockConfig(
      attn_config=cfg.AttentionConfig(
          num_heads=16,
          head_dim=72,
          num_query_groups=16,
          qkv_use_bias=True,
          output_proj_use_bias=True,
      ),
      ff_config=cfg.FeedForwardConfig(
          type=cfg.FeedForwardType.SEQUENTIAL,
          activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
          intermediate_size=4304,
          use_bias=True,
          pre_ff_norm_config=layer_norm,
      ),
      pre_attention_norm_config=layer_norm,
  )
  return cfg.ModelConfig(
      vocab_size=0,  # Unused by the image encoder.
      num_layers=27,
      max_seq_len=0,  # Unused by the image encoder.
      embedding_dim=1152,
      embedding_use_bias=True,
      image_embedding=cfg.ImageEmbeddingConfig(
          channels=3,
          image_size=896,
          patch_size=14,
      ),
      block_configs=encoder_block,
      final_norm_config=layer_norm,
      enable_hlfb=True,
      num_mm_tokens_per_image=256,
  )
|
130
|
+
|
131
|
+
|
132
|
+
def get_fake_image_encoder_config() -> cfg.ModelConfig:
  """Returns a tiny variant of `get_image_encoder_config()` for tests.

  The variant has 2 layers, 8x8 images with 2x2 patches (a 4x4 patch grid), a
  feed-forward width of 128, and 4 multimodal tokens per image.
  """
  fake = get_image_encoder_config()
  fake.block_config(0).ff_config.intermediate_size = 128
  embedding = fake.image_embedding
  embedding.image_size = 8
  embedding.patch_size = 2
  fake.num_layers = 2
  fake.num_mm_tokens_per_image = 4
  return fake
|
140
|
+
|
141
|
+
|
142
|
+
def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoderWithExit:
  """Builds the Gemma3 image encoder and loads its weights.

  Args:
    checkpoint_path: Path of the checkpoint to read the `TENSOR_NAMES` tensors
      from.

  Returns:
    A `SiglipVisionEncoderWithExit` in eval mode: the SigLIP encoder followed
    by the `SiglipExit` pooling layer, as the return annotation promises.
    Previously only the inner encoder was returned, which silently dropped
    the pooling step.
  """
  config = get_image_encoder_config()
  encoder = SiglipVisionEncoderWithExit(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
  # Weights go into the inner SigLIP encoder (SiglipExit holds no
  # parameters). Loosen the strictness because only the image encoder is
  # being loaded.
  loader.load(encoder.siglip_encoder, strict=False)
  # eval() on the wrapper also switches its submodules to eval mode.
  encoder.eval()
  return encoder
|
@@ -0,0 +1,436 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building a Decoder for Gemma3 model."""
|
17
|
+
|
18
|
+
from typing import List, Optional, Tuple
|
19
|
+
|
20
|
+
from ai_edge_torch.generative.layers import builder
|
21
|
+
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
22
|
+
from ai_edge_torch.generative.layers.experimental import attention
|
23
|
+
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
|
24
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
25
|
+
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
26
|
+
from ai_edge_torch.generative.utilities import model_builder
|
27
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
28
|
+
import torch
|
29
|
+
from torch import nn
|
30
|
+
|
31
|
+
|
32
|
+
# Checkpoint tensor names for the Gemma3 decoder. `{}` is a per-layer
# placeholder (presumably filled with the decoder layer index by the loader).
# No lm_head entry: the decoder reuses the embedding as its output projection.
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
    embedding="model.embed_tokens",
    pre_attn_norm="model.layers.{}.input_layernorm",
    attn_query_proj="model.layers.{}.self_attn.q_proj",
    attn_key_proj="model.layers.{}.self_attn.k_proj",
    attn_value_proj="model.layers.{}.self_attn.v_proj",
    attn_query_norm="model.layers.{}.self_attn.q_norm",
    attn_key_norm="model.layers.{}.self_attn.k_norm",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
    final_norm="model.norm",
    lm_head=None,
)
|
50
|
+
|
51
|
+
# Deprecated mapping that uses a fused QKV projection and an `embedder`
# embedding name. Do not use it for checkpoints hosted on Kaggle or
# HuggingFace; it will be removed in the future.
TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
    embedding="embedder",
    pre_attn_norm="model.layers.{}.input_layernorm",
    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
    attn_query_norm="model.layers.{}.self_attn.query_norm",
    attn_key_norm="model.layers.{}.self_attn.key_norm",
    attn_output_proj="model.layers.{}.self_attn.o_proj",
    post_attn_norm="model.layers.{}.post_attention_layernorm",
    pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
    ff_gate_proj="model.layers.{}.mlp.gate_proj",
    ff_up_proj="model.layers.{}.mlp.up_proj",
    ff_down_proj="model.layers.{}.mlp.down_proj",
    post_ff_norm="model.layers.{}.post_feedforward_layernorm",
    final_norm="model.norm",
    lm_head=None,
)
|
69
|
+
|
70
|
+
|
71
|
+
class DecoderBlock(attention.TransformerBlock):
  """Gemma3 transformer block.

  Identical to `attention.TransformerBlock` except that the post-attention
  norm is applied to the attention output before the residual addition.
  """

  def forward(
      self,
      x: torch.Tensor,
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
      mask: Optional[torch.Tensor] = None,
      input_pos: Optional[torch.Tensor] = None,
      kv_cache: Optional[kv_utils.KVCacheEntryBase] = None,
  ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntryBase]]:
    """Runs attention and feed-forward with residual connections.

    Args:
      x (torch.Tensor): the input tensor.
      rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
      mask (torch.Tensor): the optional mask tensor.
      input_pos (torch.Tensor): the optional input position tensor.
      kv_cache (KVCacheEntry): the optional kv cache entry.

    Returns:
      A tuple of the block's output activation and the KV cache entry returned
      by the attention function.
    """
    normed_input = self.pre_atten_norm(x)
    attn_result, updated_kv = self.atten_func(
        normed_input, rope, mask, input_pos, kv_cache
    )
    # Norm the attention output itself, then add the residual.
    hidden = x + self.post_atten_norm(attn_result)
    return hidden + self.ff(hidden), updated_kv
|
104
|
+
|
105
|
+
|
106
|
+
class Decoder(nn.Module):
  """A Gemma3 decoder model built from the Edge Generative API layers."""

  def __init__(self, config: cfg.ModelConfig):
    super().__init__()

    # Construct model layers.
    self.tok_embedding = nn.Embedding(
        config.vocab_size, config.embedding_dim, padding_idx=0
    )
    self.lm_head = nn.Linear(
        config.embedding_dim,
        config.vocab_size,
        bias=config.lm_head_use_bias,
    )
    # Gemma3 re-uses the embedding as the head projection layer.
    self.lm_head.weight.data = self.tok_embedding.weight.data
    self.transformer_blocks = nn.ModuleList(
        DecoderBlock(config.block_config(idx), config)
        for idx in range(config.num_layers)
    )
    self.final_norm = builder.build_norm(
        config.embedding_dim,
        config.final_norm_config,
    )
    # Mask caches sized to the KV cache; `get_attention_mask` selects rows
    # from them by position.
    self.mask_cache = attn_utils.build_causal_mask_cache(
        size=config.kv_cache_max,
    )
    # Gemma3 has same hyper parameters for each layer except for attention
    # types. Use the first layer.
    attn_config = config.block_config(0).attn_config
    self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
        size=config.kv_cache_max,
        window_size=attn_config.sliding_window_size,
    )
    self.config = config

  def get_attention_mask(
      self,
      attn_type: cfg.AttentionType,
      input_pos: torch.Tensor,
  ) -> torch.Tensor:
    """Selects the cached mask rows for `input_pos`.

    Args:
      attn_type: The layer's attention type. LOCAL_SLIDING layers read the
        sliding-window mask cache; all other layers read the causal mask cache.
      input_pos: Positions of the current tokens, used to index dim 2 of the
        cache.

    Returns:
      The selected slice of the corresponding mask cache.
    """
    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
      return self.sliding_window_mask_cache.index_select(2, input_pos)
    return self.mask_cache.index_select(2, input_pos)

  def get_local_global_attention_mask(
      self,
      attention_mask: torch.Tensor,
      attn_type: cfg.AttentionType,
      segment_pos: torch.Tensor,
      sliding_window_size: int,
  ) -> torch.Tensor:
    """Returns the attention mask for one layer (PyTorch).

    For LOCAL_SLIDING layers, `attention_mask` is combined with a sliding
    window mask via element-wise min. This acts as a logical AND, assuming
    both masks hold only 0 / -inf entries. Other layers get `attention_mask`
    unchanged.

    Args:
      attention_mask: A single mask tensor; its last dimension is taken as the
        cache length.
      attn_type: The layer's attention type.
      segment_pos: Query positions (callers in this class pass `input_pos`).
      sliding_window_size: Window size used for LOCAL_SLIDING layers.
    """
    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
      sliding_mask = self.create_sliding_mask(
          segment_pos=segment_pos,
          cache_len=attention_mask.shape[-1],
          sliding_window_size=sliding_window_size,
      )
      # Combine masks using logical AND (min in this case).
      return torch.min(attention_mask, sliding_mask)
    return attention_mask

  def create_sliding_mask(
      self,
      segment_pos: torch.Tensor,
      cache_len: int,
      sliding_window_size: int,
  ) -> torch.Tensor:
    """Creates an additive mask for sliding window attention (PyTorch).

    Cache slot `c` is visible from query position `p` iff
    `p - sliding_window_size < c < p + sliding_window_size`.

    Args:
      segment_pos: Query positions (callers in this class pass `input_pos`).
      cache_len: Number of cache slots, i.e. the size of the last mask dim.
      sliding_window_size: Exclusive window radius around each query.

    Returns:
      A float tensor holding 0.0 where attention is allowed and -inf
      elsewhere. It is broadcast from `segment_pos.unsqueeze(-1)` against a
      `[1, 1, cache_len]` range.
    """
    cache_positions = torch.arange(
        cache_len, dtype=torch.int32, device=segment_pos.device
    ).view(1, 1, -1)  # [1, 1, cache_len]
    query_positions = segment_pos.unsqueeze(-1)

    # Create boolean masks for window boundaries (both exclusive) and AND them.
    left_boundary = cache_positions > query_positions - sliding_window_size
    right_boundary = cache_positions < query_positions + sliding_window_size
    sliding_mask_bool = left_boundary & right_boundary

    # Convert boolean mask to float mask with 0 and -inf.
    return torch.where(
        sliding_mask_bool,
        torch.zeros_like(sliding_mask_bool, dtype=torch.float),
        torch.full_like(sliding_mask_bool, float("-inf"), dtype=torch.float),
    )

  def compose_mask(
      self,
      mask: torch.Tensor,
      pixel_mask: torch.Tensor,
      attn_type: cfg.AttentionType,
  ) -> torch.Tensor:
    """Merges an additive attention mask with a boolean image-token mask.

    Positions where `mask == 0` count as allowed. LOCAL_SLIDING layers AND
    this with `pixel_mask`; all other layers OR it. The result is converted
    back to an additive 0 / -inf mask.
    """
    mask = mask == 0
    if attn_type == cfg.AttentionType.LOCAL_SLIDING:
      # NOTE(review): the AND leaves every query row of a non-image token
      # fully masked (pixel_mask is False there), which would yield NaN
      # softmax rows. Confirm this is intended; bidirectional image attention
      # is usually expressed as (mask OR pixel_mask) restricted to the window.
      mask = torch.logical_and(mask, pixel_mask)
    else:
      mask = torch.logical_or(mask, pixel_mask)
    mask = torch.where(mask, 0, float("-inf"))
    return mask

  def build_pixel_mask(self, image_indices: torch.Tensor) -> torch.Tensor:
    """Builds a pairwise image-token mask.

    Args:
      image_indices: 2-D `[B, T]` tensor; entries >= 0 mark image tokens. It
        is padded with non-image entries up to `config.kv_cache_max`.

    Returns:
      A bool tensor `[B, 1, S, S]` that is True where both the query and key
      positions are image tokens. S is `kv_cache_max`, or T if T is larger.
    """
    pixel_mask = image_indices >= 0
    max_seq_len = self.config.kv_cache_max
    if pixel_mask.size(1) < max_seq_len:
      # Pad with bool zeros so the mask keeps its bool dtype (the previous
      # float padding silently promoted the whole mask to float).
      padding = torch.zeros(
          (pixel_mask.size(0), max_seq_len - pixel_mask.size(1)),
          dtype=torch.bool,
          device=pixel_mask.device,
      )
      pixel_mask = torch.cat([pixel_mask, padding], dim=1)
    pixel_mask = torch.logical_and(
        pixel_mask.unsqueeze(1), pixel_mask.unsqueeze(-1)
    )
    return pixel_mask.unsqueeze(1)

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCacheBase,
      input_embeds: Optional[torch.Tensor] = None,
      mask: Optional[torch.Tensor] = None,
      image_indices: Optional[torch.Tensor] = None,
      export_config: Optional[model_builder.ExportConfig] = None,
  ) -> dict[str, torch.Tensor | kv_utils.KVCacheBase]:
    """Runs the decoder.

    Args:
      tokens: Token ids; embedded (and scaled by `config.embedding_scale`)
        only when `input_embeds` is None.
      input_pos: Positions of the current tokens.
      kv_cache: KV cache with one entry per transformer block.
      input_embeds: Optional precomputed input embeddings.
      mask: Optional attention mask. When None, a per-layer list of masks is
        built from the causal / sliding-window mask caches.
      image_indices: Optional indices marking image tokens (>= 0); used to
        build the pixel mask.
      export_config: Optional export configuration.

    Returns:
      A dict with "kv_cache", plus "logits" unless logits are skipped on
      prefill per `export_config`.
    """
    pixel_mask = None
    if input_embeds is None:
      # token embeddings of shape (b, t, n_embd)
      input_embeds = self.tok_embedding(tokens)
      if self.config.embedding_scale is not None:
        input_embeds = input_embeds * self.config.embedding_scale
    if image_indices is not None:
      pixel_mask = self.build_pixel_mask(image_indices)
    # RoPE parameters are the same for all blocks. Use the first layer.
    attn_config = self.config.block_config(0).attn_config
    # Different rotary base for global and local attention
    # based on attention pattern
    rope = [
        rotary_pos_emb.build_rope(
            input_pos,
            attn_config.head_dim,
            self.config.block_config(i).attn_config.rotary_base,
        )
        for i in range(self.config.num_layers)
    ]
    if mask is None:
      mask = [
          self.get_attention_mask(
              self.config.block_config(i).attn_config.attn_type, input_pos
          )
          for i in range(self.config.num_layers)
      ]

    return self._forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
    )

  def _forward_with_embeds(
      self,
      input_embeds: torch.Tensor,
      rope: List[Tuple[torch.Tensor, torch.Tensor]],
      mask: torch.Tensor | List[torch.Tensor],
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCacheBase,
      pixel_mask: Optional[torch.Tensor] = None,
      export_config: Optional[model_builder.ExportConfig] = None,
  ) -> dict[str, torch.Tensor | kv_utils.KVCacheBase]:
    """Forwards the model with input embeddings."""
    assert len(self.transformer_blocks) == len(kv_cache.caches), (
        "The number of transformer blocks and the number of KV cache entries"
        " must be the same."
    )

    x = input_embeds
    num_layers = self.config.num_layers
    attn_configs = [
        self.config.block_config(i).attn_config for i in range(num_layers)
    ]
    # `mask` is either one tensor shared by all layers or one tensor per
    # layer (what `forward` builds when no mask is given). Normalize it so the
    # per-layer composition below always receives a single tensor. Previously
    # the whole list was passed on, which crashed on `.shape`.
    layer_masks = mask if isinstance(mask, list) else [mask] * num_layers

    if pixel_mask is None:
      mask = [
          self.get_local_global_attention_mask(
              layer_masks[i],
              attn_configs[i].attn_type,
              input_pos,
              attn_configs[i].sliding_window_size,
          )
          for i in range(num_layers)
      ]
    else:
      pixel_mask = pixel_mask.index_select(2, input_pos)
      mask = [
          self.compose_mask(
              layer_masks[i], pixel_mask, attn_configs[i].attn_type
          )
          for i in range(num_layers)
      ]

    updated_kv_entries = []
    for i, block in enumerate(self.transformer_blocks):
      kv_entry = kv_cache.caches[i] if kv_cache else None
      x, kv_entry = block(x, rope[i], mask[i], input_pos, kv_entry)
      if kv_entry:
        updated_kv_entries.append(kv_entry)
    updated_kv_cache = kv_utils.KVCacheBase(tuple(updated_kv_entries))

    # Skip the final norm / LM head on multi-token (prefill) calls when the
    # export config does not ask for prefill logits.
    if (
        export_config is not None
        and torch.numel(input_pos) > 1
        and not export_config.output_logits_on_prefill
    ):
      return {"kv_cache": updated_kv_cache}

    x = self.final_norm(x)
    res = self.lm_head(x)  # (b, t, vocab_size)

    return {"logits": res, "kv_cache": updated_kv_cache}
|
336
|
+
|
337
|
+
|
338
|
+
def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
  """Builds the model config for a Gemma3 1B decoder.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 2048.

  Returns:
    The model config for a Gemma 1B model: 26 layers, 1152-dim embeddings,
    with every sixth layer (1-based) using global attention and the rest
    using local sliding-window attention.
  """
  num_layers = 26
  embedding_dim = 1152

  # A single RMSNorm config object is shared by every norm site.
  rms_norm = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM,
      epsilon=1e-6,
      zero_centered=True,
      enable_hlfb=True,
  )
  # A single feed-forward config object is shared by every layer.
  shared_ff = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=6 * embedding_dim,
      pre_ff_norm_config=rms_norm,
      post_ff_norm_config=rms_norm,
  )

  def make_block(layer_idx: int) -> cfg.TransformerBlockConfig:
    is_global = (layer_idx + 1) % 6 == 0
    return cfg.TransformerBlockConfig(
        attn_config=cfg.AttentionConfig(
            num_heads=4,
            head_dim=256,
            num_query_groups=1,
            # Global layers use a larger RoPE base than local layers.
            rotary_base=1_000_000 if is_global else 10_000,
            rotary_percentage=1.0,
            qkv_transpose_before_split=True,
            query_norm_config=rms_norm,
            key_norm_config=rms_norm,
            logit_softcap=None,
            sliding_window_size=512,
            attn_type=(
                cfg.AttentionType.GLOBAL
                if is_global
                else cfg.AttentionType.LOCAL_SLIDING
            ),
        ),
        ff_config=shared_ff,
        pre_attention_norm_config=rms_norm,
        post_attention_norm_config=rms_norm,
    )

  return cfg.ModelConfig(
      vocab_size=262_144,
      num_layers=num_layers,
      max_seq_len=32_768,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
      kv_cache_max_len=kv_cache_max_len,
      block_configs=[make_block(idx) for idx in range(num_layers)],
      final_norm_config=rms_norm,
      lm_head_use_bias=False,
      enable_hlfb=True,
      final_logit_softcap=None,
  )
|
403
|
+
|
404
|
+
|
405
|
+
def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
  """Returns a shrunken Gemma3 1B config for tests.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 128.

  Returns:
    A fake model config for a Gemma 1B model: 2 layers, 128-dim embeddings, a
    128-token vocabulary, and 4 heads of size 64 with a 64-token window.
  """
  fake = get_decoder_config_1b(kv_cache_max_len)
  fake.vocab_size = 128
  fake.num_layers = 2
  fake.max_seq_len = 2 * kv_cache_max_len
  fake.embedding_dim = 128
  fake.embedding_scale = fake.embedding_dim**0.5
  # Keep only the first `num_layers` block configs.
  fake.block_configs = fake.block_configs[: fake.num_layers]
  for block in fake.block_configs:
    attn = block.attn_config
    attn.num_heads = 4
    attn.head_dim = 64
    attn.sliding_window_size = 64
    block.ff_config.intermediate_size = 128
  return fake
|
428
|
+
|
429
|
+
|
430
|
+
def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds a Gemma3 1B `Decoder` via `model_builder.build_decoder_only_model`.

  Args:
    checkpoint_path: Checkpoint path forwarded to the model builder.
    **kwargs: Forwarded to `get_decoder_config_1b` (e.g. `kv_cache_max_len`).

  Returns:
    The module returned by `model_builder.build_decoder_only_model`.
  """
  decoder_config = get_decoder_config_1b(**kwargs)
  return model_builder.build_decoder_only_model(
      checkpoint_path=checkpoint_path,
      config=decoder_config,
      tensor_names=TENSOR_NAMES,
      model_class=Decoder,
  )
|