ai-edge-torch-nightly 0.4.0.dev20250311__py3-none-any.whl → 0.4.0.dev20250313__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,149 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building the image encoder (SigLIP) of the Gemma3 model."""
+
+ from ai_edge_torch.generative.examples.paligemma import image_encoder
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
+     ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
+     attn_query_proj=(
+         "vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
+     ),
+     attn_key_proj=(
+         "vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
+     ),
+     attn_value_proj=(
+         "vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
+     ),
+     attn_output_proj=(
+         "vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
+     ),
+     pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
+     embedding="vision_tower.vision_model.embeddings.patch_embedding",
+     embedding_position=(
+         "vision_tower.vision_model.embeddings.position_embedding.weight"
+     ),
+     final_norm="vision_tower.vision_model.post_layernorm",
+ )
+
+
+ class SiglipExit(nn.Module):
+   """Siglip exit layer."""
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+     self.expected_length = config.num_mm_tokens_per_image**0.5
+
+   def forward(self, x: torch.Tensor) -> torch.Tensor:
+     current_tokens = x.shape[1]
+     current_length = int(current_tokens**0.5)
+     if current_length != self.expected_length:
+       window_size = int(current_length // self.expected_length)
+       x = x.transpose(1, 2)
+       x = x.view(x.shape[0], x.shape[1], current_length, current_length)
+       x = F.avg_pool2d(x, window_size, stride=window_size)
+       x = x.view(x.shape[0], x.shape[1], -1)
+       x = x.transpose(1, 2)
+     return x
+
+ class SiglipVisionEncoderWithExit(nn.Module):
+   """Siglip vision encoder for Gemma3MM from the Edge Generative API."""
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+     self.siglip_encoder = image_encoder.SiglipVisionEncoder(config)
+     self.siglip_exit = SiglipExit(config)
+
+   def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+     x = self.siglip_encoder(pixel_values)
+     x = self.siglip_exit(x)
+     return x
+
+ def get_image_encoder_config() -> cfg.ModelConfig:
+   """Returns the model config for the image encoder of a Gemma3 4B model.
+
+   Returns:
+     The model config for the image encoder of a Gemma3 4B model.
+   """
+   image_embedding_config = cfg.ImageEmbeddingConfig(
+       channels=3,
+       image_size=896,
+       patch_size=14,
+   )
+   attn_config = cfg.AttentionConfig(
+       num_heads=16,
+       head_dim=72,
+       num_query_groups=16,
+       qkv_use_bias=True,
+       output_proj_use_bias=True,
+   )
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.LAYER_NORM,
+       epsilon=1e-6,
+       enable_hlfb=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+       intermediate_size=4304,
+       use_bias=True,
+       pre_ff_norm_config=norm_config,
+   )
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+   )
+   config = cfg.ModelConfig(
+       vocab_size=0,  # Not used in image encoder.
+       num_layers=27,
+       max_seq_len=0,  # Not used in image encoder.
+       embedding_dim=1152,
+       embedding_use_bias=True,
+       image_embedding=image_embedding_config,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+       num_mm_tokens_per_image=256,
+   )
+   return config
+
+
+ def get_fake_image_encoder_config() -> cfg.ModelConfig:
+   config = get_image_encoder_config()
+   config.block_config(0).ff_config.intermediate_size = 128
+   config.image_embedding.image_size = 8
+   config.image_embedding.patch_size = 2
+   config.num_layers = 2
+   config.num_mm_tokens_per_image = 4
+   return config
+
+
+ def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoderWithExit:
+   config = get_image_encoder_config()
+   encoder = SiglipVisionEncoderWithExit(config).siglip_encoder
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   # Loosen the strictness because only the image encoder is being loaded.
+   loader.load(encoder, strict=False)
+   encoder.eval()
+   return encoder
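
For reference, the token reduction implemented by SiglipExit above can be exercised on its own. The following is a minimal, standalone sketch (not part of the package) that pushes a dummy tensor with the shapes implied by get_image_encoder_config() (896 / 14 = 64 patches per side, num_mm_tokens_per_image=256) through the same average-pooling logic:

    # Standalone sketch of the SiglipExit pooling; dummy data, no checkpoint needed.
    import torch
    import torch.nn.functional as F

    batch, dim = 1, 1152                             # embedding_dim in the config above
    current_length = 896 // 14                       # 64 patches per side
    expected_length = int(256**0.5)                  # sqrt(num_mm_tokens_per_image) = 16
    window_size = current_length // expected_length  # 4

    x = torch.randn(batch, current_length**2, dim)   # stand-in for the SigLIP encoder output
    x = x.transpose(1, 2)
    x = x.view(batch, dim, current_length, current_length)
    x = F.avg_pool2d(x, window_size, stride=window_size)
    x = x.view(batch, dim, -1).transpose(1, 2)
    print(x.shape)  # torch.Size([1, 256, 1152])

Each image therefore contributes exactly 256 tokens to the decoder, regardless of the 4096 patch embeddings produced by the SigLIP backbone.
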
@@ -0,0 +1,436 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a Decoder for the Gemma3 model."""
+
+ from typing import List, Optional, Tuple
+
+ from ai_edge_torch.generative.layers import builder
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ from ai_edge_torch.generative.layers.experimental import attention
+ from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
+ from ai_edge_torch.generative.utilities import model_builder
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_query_proj="model.layers.{}.self_attn.q_proj",
+     attn_key_proj="model.layers.{}.self_attn.k_proj",
+     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     attn_query_norm="model.layers.{}.self_attn.q_norm",
+     attn_key_norm="model.layers.{}.self_attn.k_norm",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
+     post_ff_norm="model.layers.{}.post_feedforward_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+     lm_head=None,
+ )
+
+ # Please don't use this tensor mapping for converting checkpoints hosted on
+ # Kaggle or HuggingFace. It will be removed in the future.
+ TENSOR_NAMES_TO_BE_REMOVED = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     ff_gate_proj="model.layers.{}.mlp.gate_proj",
+     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     attn_query_norm="model.layers.{}.self_attn.query_norm",
+     attn_key_norm="model.layers.{}.self_attn.key_norm",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
+     post_ff_norm="model.layers.{}.post_feedforward_layernorm",
+     embedding="embedder",
+     final_norm="model.norm",
+     lm_head=None,
+ )
+
+
+ class DecoderBlock(attention.TransformerBlock):
+
+   def forward(
+       self,
+       x: torch.Tensor,
+       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+       mask: Optional[torch.Tensor] = None,
+       input_pos: Optional[torch.Tensor] = None,
+       kv_cache: kv_utils.KVCacheEntryBase = None,
+   ) -> Tuple[torch.Tensor, Optional[kv_utils.KVCacheEntryBase]]:
+     """Forward function of the DecoderBlock.
+
+     Exactly the same as TransformerBlock, except that the post-attention norm
+     is applied immediately after attention rather than after the residual
+     pointwise addition.
+
+     Args:
+       x (torch.Tensor): the input tensor.
+       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+       mask (torch.Tensor): the optional mask tensor.
+       input_pos (torch.Tensor): the optional input position tensor.
+       kv_cache (KVCacheEntry): the optional kv cache entry.
+
+     Returns:
+       output activation from this transformer block, and updated kv cache (if
+       passed in).
+     """
+
+     x_norm = self.pre_atten_norm(x)
+     attn_out, kv = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+     attn_out_norm = self.post_atten_norm(attn_out)
+     x = x + attn_out_norm
+     output = x + self.ff(x)
+     return output, kv
+
+
+ class Decoder(nn.Module):
+   """A Gemma3 decoder model built from the Edge Generative API layers."""
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     # Construct model layers.
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     self.lm_head = nn.Linear(
+         config.embedding_dim,
+         config.vocab_size,
+         bias=config.lm_head_use_bias,
+     )
+     # Gemma3 re-uses the embedding as the head projection layer.
+     self.lm_head.weight.data = self.tok_embedding.weight.data
+     self.transformer_blocks = nn.ModuleList(
+         DecoderBlock(config.block_config(idx), config)
+         for idx in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max,
+     )
+     # Gemma3 has the same hyperparameters for each layer except for the
+     # attention type. Use the first layer.
+     attn_config = config.block_config(0).attn_config
+     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
+         size=config.kv_cache_max,
+         window_size=attn_config.sliding_window_size,
+     )
+     self.config = config
+
+   def get_attention_mask(
+       self,
+       attn_type: cfg.AttentionType,
+       input_pos: torch.Tensor,
+   ) -> torch.Tensor:
+     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+       return self.sliding_window_mask_cache.index_select(2, input_pos)
+     return self.mask_cache.index_select(2, input_pos)
+
+   def get_local_global_attention_mask(
+       self,
+       attention_mask: torch.Tensor,
+       attn_type: cfg.AttentionType,
+       segment_pos: torch.Tensor,
+       sliding_window_size: int,
+   ) -> torch.Tensor:
+     """Returns the attention mask for the current batch (PyTorch)."""
+     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+       sliding_mask = self.create_sliding_mask(
+           segment_pos=segment_pos,
+           cache_len=attention_mask.shape[-1],
+           sliding_window_size=sliding_window_size,
+       )
+       # Combine masks using logical AND (min in this case).
+       combined_mask = torch.min(attention_mask, sliding_mask)
+       return combined_mask
+     return attention_mask
+
+   def create_sliding_mask(
+       self,
+       segment_pos: torch.Tensor,  # [B, L]
+       cache_len: int,
+       sliding_window_size: int,
+   ) -> torch.Tensor:
+     """Creates mask for sliding window attention (PyTorch)."""
+     cache_positions = torch.tensor(
+         [i for i in range(cache_len)], dtype=torch.int32
+     )
+     cache_positions = cache_positions.view(1, 1, -1)  # [1, 1, cache_len]
+     segment_pos_expanded = segment_pos.clone().unsqueeze(-1)  # [B, seq_len, 1]
+
+     # Create boolean masks for window boundaries.
+     left_boundary = cache_positions > segment_pos_expanded - sliding_window_size
+     right_boundary = (
+         cache_positions < segment_pos_expanded + sliding_window_size
+     )
+
+     # Combine boolean masks (AND).
+     sliding_mask_bool = left_boundary & right_boundary
+
+     # Convert boolean mask to float mask with 0 and -inf.
+     sliding_mask = torch.where(
+         sliding_mask_bool,
+         torch.zeros_like(sliding_mask_bool, dtype=torch.float),
+         torch.full_like(sliding_mask_bool, float("-inf"), dtype=torch.float),
+     )
+
+     return sliding_mask
+
+   def compose_mask(
+       self,
+       mask: torch.Tensor,
+       pixel_mask: torch.Tensor,
+       attn_type: cfg.AttentionType,
+   ) -> torch.Tensor:
+     mask = mask == 0
+     if attn_type == cfg.AttentionType.LOCAL_SLIDING:
+       mask = torch.logical_and(mask, pixel_mask)
+     else:
+       mask = torch.logical_or(mask, pixel_mask)
+     mask = torch.where(mask, 0, float("-inf"))
+     return mask
+
+   def build_pixel_mask(self, image_indices: torch.Tensor):
+     pixel_mask = image_indices >= 0
+     max_seq_len = self.config.kv_cache_max
+     if pixel_mask.size(1) < max_seq_len:
+       pixel_mask = torch.cat(
+           [
+               pixel_mask,
+               torch.zeros(
+                   (pixel_mask.size(0), max_seq_len - pixel_mask.size(1))
+               ),
+           ],
+           dim=1,
+       )
+     pixel_mask = torch.logical_and(
+         pixel_mask.unsqueeze(1), pixel_mask.unsqueeze(-1)
+     )
+     return pixel_mask.unsqueeze(1)
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCacheBase,
+       input_embeds: Optional[torch.Tensor] = None,
+       mask: Optional[torch.Tensor] = None,
+       image_indices: Optional[torch.Tensor] = None,
+       export_config: Optional[model_builder.ExportConfig] = None,
+   ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
+
+     pixel_mask = None
+     if input_embeds is None:
+       # token embeddings of shape (b, t, n_embd)
+       input_embeds = self.tok_embedding(tokens)
+       if self.config.embedding_scale is not None:
+         input_embeds = input_embeds * self.config.embedding_scale
+     if image_indices is not None:
+       pixel_mask = self.build_pixel_mask(image_indices)
+     # RoPE parameters are the same for all blocks. Use the first layer.
+     attn_config = self.config.block_config(0).attn_config
+     n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+     # A different rotary base is used for global and local attention,
+     # based on the attention pattern.
+     rope = [
+         rotary_pos_emb.build_rope(
+             input_pos,
+             attn_config.head_dim,
+             self.config.block_config(i).attn_config.rotary_base,
+         )
+         for i in range(self.config.num_layers)
+     ]
+     if mask is None:
+       mask = [
+           self.get_attention_mask(
+               self.config.block_config(i).attn_config.attn_type, input_pos
+           )
+           for i in range(self.config.num_layers)
+       ]
+
+     return self._forward_with_embeds(
+         input_embeds, rope, mask, input_pos, kv_cache, pixel_mask, export_config
+     )
+
+   def _forward_with_embeds(
+       self,
+       input_embeds: torch.Tensor,
+       rope: List[Tuple[torch.Tensor, torch.Tensor]],
+       mask: torch.Tensor | List[torch.Tensor],
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCacheBase,
+       pixel_mask: Optional[torch.Tensor] = None,
+       export_config: Optional[model_builder.ExportConfig] = None,
+   ) -> dict[torch.Tensor, kv_utils.KVCacheBase]:
+     """Forwards the model with input embeddings."""
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
+
+     x = input_embeds
+
+     if pixel_mask is None:
+       mask = [
+           self.get_local_global_attention_mask(
+               mask,
+               self.config.block_config(i).attn_config.attn_type,
+               input_pos,
+               self.config.block_config(i).attn_config.sliding_window_size,
+           )
+           for i in range(self.config.num_layers)
+       ]
+     else:
+       pixel_mask = pixel_mask.index_select(2, input_pos)
+       mask = [
+           self.compose_mask(
+               mask[i],
+               pixel_mask,
+               self.config.block_config(i).attn_config.attn_type,
+           )
+           for i in range(self.config.num_layers)
+       ]
+     updated_kv_entries = []
+     for i, block in enumerate(self.transformer_blocks):
+       mask_entry = mask[i] if isinstance(mask, list) else mask
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, rope[i], mask_entry, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entries.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCacheBase(tuple(updated_kv_entries))
+     if export_config is not None:
+       if (
+           torch.numel(input_pos) > 1
+           and not export_config.output_logits_on_prefill
+       ):
+         return {"kv_cache": updated_kv_cache}
+
+     x = self.final_norm(x)
+     res = self.lm_head(x)  # (b, t, vocab_size)
+
+     return {"logits": res, "kv_cache": updated_kv_cache}
+
+
+ def get_decoder_config_1b(kv_cache_max_len: int = 2048) -> cfg.ModelConfig:
+   """Returns the model config for a Gemma3 1B model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+       is 2048.
+
+   Returns:
+     The model config for a Gemma3 1B model.
+   """
+   norm_config = cfg.NormalizationConfig(
+       type=cfg.NormalizationType.RMS_NORM,
+       epsilon=1e-6,
+       zero_centered=True,
+       enable_hlfb=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.GATED,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+       intermediate_size=6 * 1152,
+       pre_ff_norm_config=norm_config,
+       post_ff_norm_config=norm_config,
+   )
+
+   def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
+     attn_config = cfg.AttentionConfig(
+         num_heads=4,
+         head_dim=256,
+         num_query_groups=1,
+         rotary_base=1_000_000 if (idx + 1) % 6 == 0 else 10_000,
+         rotary_percentage=1.0,
+         qkv_transpose_before_split=True,
+         query_norm_config=norm_config,
+         key_norm_config=norm_config,
+         logit_softcap=None,
+         sliding_window_size=512,
+         attn_type=(
+             cfg.AttentionType.GLOBAL
+             if (idx + 1) % 6 == 0
+             else cfg.AttentionType.LOCAL_SLIDING
+         ),
+     )
+     return cfg.TransformerBlockConfig(
+         attn_config=attn_config,
+         ff_config=ff_config,
+         pre_attention_norm_config=norm_config,
+         post_attention_norm_config=norm_config,
+     )
+
+   num_layers = 26
+   embedding_dim = 1152
+   config = cfg.ModelConfig(
+       vocab_size=262_144,
+       num_layers=num_layers,
+       max_seq_len=32_768,
+       embedding_dim=embedding_dim,
+       embedding_scale=embedding_dim**0.5,
+       kv_cache_max_len=kv_cache_max_len,
+       block_configs=[get_block_config(i) for i in range(num_layers)],
+       final_norm_config=norm_config,
+       lm_head_use_bias=False,
+       enable_hlfb=True,
+       final_logit_softcap=None,
+   )
+   return config
+
+
+ def get_fake_decoder_config_1b(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+   """Returns a fake model config for a Gemma3 1B model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+       is 128.
+
+   Returns:
+     A fake model config for a Gemma3 1B model.
+   """
+   config = get_decoder_config_1b(kv_cache_max_len)
+   config.vocab_size = 128
+   config.num_layers = 2
+   config.max_seq_len = 2 * kv_cache_max_len
+   config.embedding_dim = 128
+   config.embedding_scale = config.embedding_dim**0.5
+   config.block_configs = config.block_configs[: config.num_layers]
+   for block_config in config.block_configs:
+     block_config.attn_config.num_heads = 4
+     block_config.attn_config.head_dim = 64
+     block_config.attn_config.sliding_window_size = 64
+     block_config.ff_config.intermediate_size = 128
+   return config
+
+
+ def build_model_1b(checkpoint_path: str, **kwargs) -> nn.Module:
+   return model_builder.build_decoder_only_model(
+       checkpoint_path=checkpoint_path,
+       config=get_decoder_config_1b(**kwargs),
+       tensor_names=TENSOR_NAMES,
+       model_class=Decoder,
+   )
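
For reference, the interplay between the causal mask and the sliding-window mask used by the local-attention layers can be reproduced without instantiating the model. The sketch below (not part of the package) mirrors Decoder.create_sliding_mask and get_local_global_attention_mask with small dummy sizes; in the 1B config above the window is 512 tokens and every sixth block, i.e. (idx + 1) % 6 == 0, uses AttentionType.GLOBAL while the remaining blocks use LOCAL_SLIDING:

    # Standalone sketch of the local/global masking; dummy sizes, no model needed.
    import torch

    cache_len, window_size = 16, 4               # stand-ins for kv_cache_max and sliding_window_size
    segment_pos = torch.arange(8).unsqueeze(0)   # [B=1, L=8] positions of the prefill tokens

    # Causal mask: 0 where attention is allowed, -inf elsewhere ([1, 1, L, cache_len]).
    causal = torch.triu(
        torch.full((cache_len, cache_len), float("-inf")), diagonal=1
    )[None, None, : segment_pos.shape[1], :]

    # Sliding-window mask, mirroring Decoder.create_sliding_mask.
    cache_positions = torch.arange(cache_len, dtype=torch.int32).view(1, 1, -1)
    segment_pos_expanded = segment_pos.unsqueeze(-1)
    in_window = (cache_positions > segment_pos_expanded - window_size) & (
        cache_positions < segment_pos_expanded + window_size
    )
    sliding = torch.where(
        in_window,
        torch.zeros_like(in_window, dtype=torch.float),
        torch.full_like(in_window, float("-inf"), dtype=torch.float),
    ).unsqueeze(1)

    # Local layers take the elementwise min of both masks; global layers keep the causal one.
    local_mask = torch.min(causal, sliding)
    print(local_mask.shape)  # torch.Size([1, 1, 8, 16])

With one such mask per layer, _forward_with_embeds simply selects the matching entry for each DecoderBlock and threads the KV cache through the stack.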