ai-edge-torch-nightly 0.4.0.dev20250310__py3-none-any.whl → 0.4.0.dev20250312__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/generative/examples/gemma3/gemma3.py ADDED
@@ -0,0 +1,175 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a Gemma3 GPU model."""
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+from ai_edge_torch.generative.examples.gemma3 import decoder
+from ai_edge_torch.generative.examples.gemma3.cpu_only import image_encoder
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+from ai_edge_torch.generative.utilities import model_builder
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+
+PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
+
+
+@dataclass
+class Gemma3MMConfig:
+  """Gemma3 model configurations."""
+
+  image_encoder_config: cfg.ModelConfig
+  decoder_config: cfg.ModelConfig
+  mm_norm_config: cfg.NormalizationConfig
+  mm_extra_tokens: int
+  image_token_id: int
+  image_projection_scale: float
+  image_projection_use_bias: bool = False
+
+
+class Gemma3MM(nn.Module):
+  """A Gemma3 multimodal model built from the Edge Generative API layers."""
+
+  def __init__(self, config: Gemma3MMConfig):
+    super().__init__()
+
+    self.image_encoder = image_encoder.SiglipVisionEncoderWithExit(
+        config.image_encoder_config
+    )
+    self.decoder = decoder.Decoder(config.decoder_config)
+    self.mm_norm = builder.build_norm(
+        config.image_encoder_config.embedding_dim,
+        config.mm_norm_config,
+    )
+    self.extra_embedding = nn.Embedding(
+        config.mm_extra_tokens, config.image_encoder_config.embedding_dim
+    )
+    self.image_projection = nn.Linear(
+        config.image_encoder_config.embedding_dim,
+        config.decoder_config.embedding_dim,
+        bias=config.image_projection_use_bias,
+    )
+    image_embedding_config = config.image_encoder_config.image_embedding
+    self.num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      image_indices: Optional[torch.Tensor] = None,
+      image_feat_indices: Optional[torch.Tensor] = None,
+      pixel_values: torch.Tensor = None,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.decoder_config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.decoder_config.max_seq_len}"
+    )
+    if pixel_values is None:
+      return self.decoder(
+          tokens=tokens,
+          input_pos=input_pos,
+          kv_cache=kv_cache,
+          input_embeds=None,
+          export_config=export_config,
+      )
+    vocab_size = self.config.decoder_config.vocab_size
+    input_embeds = self.decoder.tok_embedding(
+        torch.clip(tokens, 0, vocab_size - 1)
+    )
+    if self.decoder.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.decoder.config.embedding_scale
+
+    # TODO: Identify embedding path for hard tokens if required.
+    # extra_embeds = self.extra_embedding(
+    #     torch.clip(tokens - vocab_size, 0, self.config.mm_extra_tokens - 1)
+    # )
+    # extra_embeds = self.image_projection(extra_embeds)
+    # input_embeds = torch.where(tokens < self.config.decoder_config.vocab_size,
+    #                            input_embeds, extra_embeds)
+    # Alternate method of implementation:
+    # rows, cols = torch.where(tokens >= self.config.vocab_size)
+    # ext_embeds = self.ext_embedding(
+    #     tokens[rows, cols] - self.config.vocab_size
+    # )
+    # ext_embeds = self.mm_projection(extra_embeds)
+    # input_embeds[rows, cols, :] = extra_embeds
+
+    # Shape of pixel_values: (b, n, c, h, w).
+    batch_size, num_media, c, h, w = pixel_values.size()
+    pixel_values = pixel_values.view(-1, c, h, w)
+    image_encoded = self.image_encoder(pixel_values=pixel_values)
+    image_encoded = self.mm_norm(image_encoded)
+    image_encoded = self.image_projection(image_encoded)
+    _, num_patches, num_channels = image_encoded.size()
+    image_encoded = image_encoded.view(
+        batch_size, num_media, num_patches, num_channels
+    )
+
+    # Interleave the image soft embeddings with the text embeddings.
+    for b in range(tokens.shape[0]):
+      unbatched_image_encoded = image_encoded[b]
+      image_features = unbatched_image_encoded[
+          image_indices[b], image_feat_indices[b]
+      ]
+      index_to_copy = torch.where(image_indices[b] >= 0)[0]
+      input_embeds[b] = torch.index_copy(
+          input_embeds[b], 0, index_to_copy, image_features[index_to_copy]
+      )
+    return self.decoder(
+        tokens=None,
+        input_pos=input_pos,
+        kv_cache=kv_cache,
+        input_embeds=input_embeds,
+        image_indices=image_indices,
+        export_config=export_config,
+    )
+
+def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
+  return Gemma3MMConfig(
+      image_encoder_config=image_encoder.get_fake_image_encoder_config(),
+      decoder_config=decoder.get_fake_decoder_config_4b(**kwargs),
+      image_token_id=127,
+      image_projection_scale=128**0.5,
+      image_projection_use_bias=False,
+      mm_norm_config=cfg.NormalizationConfig(
+          type=cfg.NormalizationType.LAYER_NORM,
+          epsilon=1e-6,
+          enable_hlfb=True,
+      ),
+      mm_extra_tokens=32,
+  )
+
+def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
+  if checkpoint_path:
+    model = decoder.build_model_1b(checkpoint_path, **kwargs)
+  else:
+    config = decoder.get_decoder_config_1b(**kwargs)
+    model = decoder.Decoder(config)
+  # TODO: Load the parameters of decoder from checkpoint.
+  model.eval()
+  return model
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
         hidden_dim, config.normalization_config
     )
     self.attention = SelfAttention(
-        config.attention_batch_size,
         hidden_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
@@ -178,7 +177,6 @@ class CrossAttentionBlock2D(nn.Module):
         config.output_dim, config.normalization_config
     )
     self.attention = CrossAttention(
-        config.attention_batch_size,
         config.query_dim,
         config.cross_dim,
         config.hidden_dim,
@@ -305,7 +303,8 @@ class TransformerBlock2D(nn.Module):
     Args:
       config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
         block.
-      dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
+      dim_override: in case specified, overrides
+        config.attention_block_config.hidden_dim. Set to None by default.
     """
     super().__init__()
     self.config = config
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py CHANGED
@@ -41,8 +41,8 @@ class StableHLOCompositeBuilder:
     self.attr = attr
     self.name = name
     self.id = _get_uuid()
-    self._inputs = []
-    self._outputs = []
+    self._input_cnt = 0
+    self._output_cnt = 0
 
   def _mark_tensor(self, *tensors: torch.Tensor, is_input: bool):
     """Mark the input/output tensors of the StableHLO Composite."""
@@ -53,9 +53,20 @@ class StableHLOCompositeBuilder:
         else None
     )
 
-    for pos, tensor in enumerate(tensors):
+    def _pos() -> int:
+      if is_input:
+        self._input_cnt += 1
+        return self._input_cnt - 1
+      else:
+        self._output_cnt += 1
+        return self._output_cnt - 1
+
+    for tensor in tensors:
       if not isinstance(tensor, torch.Tensor):
         raise ValueError(f"input must be a torch tensor. Got {type(tensor)}.")
+
+      pos = _pos()
+
       marked_tensors.append(
           mark_tensor.mark_tensor_op(
               tensor,
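The StableHLOCompositeBuilder change above swaps the `_inputs`/`_outputs` lists for plain counters: positions still increase monotonically across repeated mark calls, but the builder no longer holds references to every marked tensor. Usage is unchanged. A minimal sketch through the public `ai_edge_torch.hlfb` wrapper (assuming, as elsewhere in this package, that the wrapper is backed by this composite machinery; the module name and `attr` payload are illustrative):

```python
import torch
import torch.nn.functional as F

from ai_edge_torch.hlfb import StableHLOCompositeBuilder


class GeluBlock(torch.nn.Module):
  """Emits gelu wrapped in a StableHLO composite when exported."""

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    builder = StableHLOCompositeBuilder(
        name="aten.gelu.default", attr={"approximate": "tanh"}
    )
    # mark_inputs/mark_outputs assign consecutive positions; after this
    # change the builder tracks only the running counts, not the tensors.
    x = builder.mark_inputs(x)
    y = F.gelu(x, approximate="tanh")
    return builder.mark_outputs(y)
```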
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.4.0.dev20250310"
+__version__ = "0.4.0.dev20250312"
ai_edge_torch_nightly-0.4.0.dev{20250310 → 20250312}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.4.0.dev20250310
+Version: 0.4.0.dev20250312
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.4.0.dev{20250310 → 20250312}.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=rkRy3DyD2DornHefvbWznzVRjj0A3967fcNBIlVQEz8,706
+ai_edge_torch/version.py,sha256=PlBamGX8JQKmFS2RJl0lWF-mxslb0_eASGwwnezOHuY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -61,6 +61,15 @@ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=BOQ4zYKMXEX8Adly9-Yt6FB
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
+ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
+ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=xAjMqhNrSv2srrBvrwCsnbLzdQXVpkZEOYImb3Mvw3w,3910
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=_7s_JrzwW4rX07f41VDuRLDZDJDshc3vqhXVY92K8q8,15423
+ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=n2EQVp5SrnMeb0csHrz46_gdNiHTpsApaRmcAc8xyj8,6482
+ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py,sha256=P11xO0F1MUbLMs8ySz6tu6qGDOOyK43q-HV_pqdsCUY,670
+ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py,sha256=4Ym4f8pvHu7dUSkTXfSToNuX8X3fhV5kKuhgEzOcyuw,3012
+ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py,sha256=fB2oYR08u7GcrWYjNbeADRZM5z1vTbE03mHXi497RRw,16140
+ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py,sha256=NeMqW67uQEQl09R7nE3vSpT84KXmAHEg9oy4-7TVC5k,8104
+ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
 ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
@@ -155,7 +164,7 @@ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=0H-Rqtm6ArMxchHS
 ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
 ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -206,7 +215,7 @@ ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5x
 ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
-ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
+ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=h6DQkYVS4fnKMALIVdU6Q7J6Ehg3hMCV4C406SyIk3k,3513
 ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=6Ns2rlfOilLJEk5cUxlkRwm2uxOgEF2-0S2DMcOqr6A,3319
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
@@ -233,8 +242,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.4.0.dev20250310.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.4.0.dev20250310.dist-info/METADATA,sha256=GXkTBv2DzWbm9p52gwaLQNl3iyKQ8qGK72is7Wy0gMM,1966
-ai_edge_torch_nightly-0.4.0.dev20250310.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.4.0.dev20250310.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.4.0.dev20250310.dist-info/RECORD,,
+ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/METADATA,sha256=DTFxRv9pdU_Uy4wUIU8th5ZpgLzkYVFmM81SgjJAzAo,1966
+ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/RECORD,,