ai-edge-torch-nightly 0.4.0.dev20250311__py3-none-any.whl → 0.4.0.dev20250312__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +124 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py +96 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py +463 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py +212 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py +149 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +436 -0
- ai_edge_torch/generative/examples/gemma3/gemma3.py +176 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -3
- ai_edge_torch/odml_torch/composite/mark_tensor.py +0 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/RECORD +17 -8
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250311.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,176 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building a Gemma3 gpu model."""
|
17
|
+
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from typing import List, Optional, Tuple
|
20
|
+
import xmlrpc
|
21
|
+
|
22
|
+
from ai_edge_torch.generative.examples.gemma3 import decoder
|
23
|
+
from ai_edge_torch.generative.examples.gemma3.cpu_only import image_encoder
|
24
|
+
from ai_edge_torch.generative.layers import builder
|
25
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
26
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
27
|
+
from ai_edge_torch.generative.utilities import model_builder
|
28
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
29
|
+
import torch
|
30
|
+
from torch import nn
|
31
|
+
|
32
|
+
|
33
|
+
# Name of the multimodal projection tensor ("multi_modal_projector.linear").
# NOTE(review): not referenced anywhere else in this module; presumably used
# when mapping checkpoint weights onto Gemma3MM.image_projection — confirm
# against the loader before relying on it.
PROJECTION_TENSOR_NAME = "multi_modal_projector.linear"
|
34
|
+
|
35
|
+
|
36
|
+
@dataclass
class Gemma3MMConfig:
  """Configuration bundle consumed by the Gemma3 multimodal model (Gemma3MM).

  Groups the image-encoder and decoder model configs together with the
  settings of the small multimodal bridge (normalization, extra-token
  embedding and image projection) that connects them.
  """

  # Image-encoder config. Its embedding_dim sizes the multimodal norm, the
  # extra-token embedding table and the input side of the image projection.
  image_encoder_config: cfg.ModelConfig
  # Text-decoder config. Its embedding_dim is the output size of the image
  # projection.
  decoder_config: cfg.ModelConfig
  # Normalization applied to image-encoder outputs before projection.
  mm_norm_config: cfg.NormalizationConfig
  # Number of rows in the extra (beyond-vocabulary) token embedding table.
  mm_extra_tokens: int
  # NOTE(review): presumably the token id reserved for image placeholders;
  # not read by Gemma3MM in this module — confirm with the caller.
  image_token_id: int
  # NOTE(review): a scale for projected image embeddings; not currently
  # applied by Gemma3MM in this module — confirm intended use.
  image_projection_scale: float
  # Whether the image projection linear layer carries a bias term.
  image_projection_use_bias: bool = False
|
47
|
+
|
48
|
+
|
49
|
+
class Gemma3MM(nn.Module):
  """A Gemma3 multimodal model built from the Edge Generative API layers.

  Combines a SigLIP vision encoder, a multimodal bridge (normalization plus a
  linear projection into the decoder's embedding space) and a text decoder.
  When images are supplied, their projected soft embeddings replace the text
  embeddings at the token positions selected by ``image_indices``.
  """

  def __init__(self, config: Gemma3MMConfig):
    super().__init__()

    self.image_encoder = image_encoder.SiglipVisionEncoderWithExit(
        config.image_encoder_config
    )
    self.decoder = decoder.Decoder(config.decoder_config)
    # Normalizes image-encoder outputs before they are projected into the
    # decoder's embedding space.
    self.mm_norm = builder.build_norm(
        config.image_encoder_config.embedding_dim,
        config.mm_norm_config,
    )
    # Embeddings for token ids beyond the text vocabulary. Not yet used by
    # forward() (see the TODO there).
    self.extra_embedding = nn.Embedding(
        config.mm_extra_tokens, config.image_encoder_config.embedding_dim
    )
    self.image_projection = nn.Linear(
        config.image_encoder_config.embedding_dim,
        config.decoder_config.embedding_dim,
        bias=config.image_projection_use_bias,
    )
    image_embedding_config = config.image_encoder_config.image_embedding
    self.num_patches = (
        image_embedding_config.image_size // image_embedding_config.patch_size
    ) ** 2
    self.config = config

  @torch.inference_mode
  def forward(
      self,
      tokens: torch.Tensor,
      input_pos: torch.Tensor,
      kv_cache: kv_utils.KVCache,
      image_indices: Optional[torch.Tensor] = None,
      image_feat_indices: Optional[torch.Tensor] = None,
      pixel_values: Optional[torch.Tensor] = None,
      export_config: Optional[model_builder.ExportConfig] = None,
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
    """Runs the decoder, optionally interleaving image soft embeddings.

    Args:
      tokens: Token ids of shape (batch, seq_len).
      input_pos: Input positions, forwarded to the decoder.
      kv_cache: KV cache, forwarded to the decoder.
      image_indices: Required when ``pixel_values`` is given. For each batch
        row, one entry per token position (assumed shape (batch, seq_len) —
        TODO confirm) naming the image along the ``num_media`` axis whose
        feature replaces that token's embedding. Negative entries keep the
        text embedding.
      image_feat_indices: Required when ``pixel_values`` is given. Same layout
        as ``image_indices``, giving the patch index inside the chosen image.
      pixel_values: Optional images of shape (batch, num_media, c, h, w).
      export_config: Optional export configuration, forwarded to the decoder.

    Returns:
      The output of ``self.decoder`` for the (possibly image-augmented)
      inputs.

    Raises:
      ValueError: If ``pixel_values`` is given without both
        ``image_indices`` and ``image_feat_indices``.
    """
    _, seq_len = tokens.size()
    assert self.config.decoder_config.max_seq_len >= seq_len, (
        f"Cannot forward sequence of length {seq_len}, max seq length is only"
        f" {self.config.decoder_config.max_seq_len}"
    )
    if pixel_values is None:
      return self.decoder(
          tokens=tokens,
          input_pos=input_pos,
          kv_cache=kv_cache,
          input_embeds=None,
          export_config=export_config,
      )
    # Without these the interleaving loop below would fail with an opaque
    # "'NoneType' object is not subscriptable" error.
    if image_indices is None or image_feat_indices is None:
      raise ValueError(
          "image_indices and image_feat_indices are required when"
          " pixel_values is provided."
      )

    vocab_size = self.config.decoder_config.vocab_size
    # Clip so ids outside the text vocabulary still index a valid embedding
    # row; image positions are overwritten with image features below.
    input_embeds = self.decoder.tok_embedding(
        torch.clip(tokens, 0, vocab_size - 1)
    )
    if self.decoder.config.embedding_scale is not None:
      input_embeds = input_embeds * self.decoder.config.embedding_scale

    # TODO: Identify embedding path for hard tokens (ids >= vocab_size) if
    # required, e.g. embed them with self.extra_embedding, project with
    # self.image_projection, and select per position with
    # torch.where(tokens < vocab_size, input_embeds, extra_embeds).

    # Shape of pixel_values: (b, n, c, h, w)
    batch_size, num_media, c, h, w = pixel_values.size()
    pixel_values = pixel_values.view(-1, c, h, w)
    image_encoded = self.image_encoder(pixel_values=pixel_values)
    image_encoded = self.mm_norm(image_encoded)
    image_encoded = self.image_projection(image_encoded)
    _, num_patches, num_channels = image_encoded.size()
    image_encoded = image_encoded.view(
        batch_size, num_media, num_patches, num_channels
    )

    # Interleave the image soft embeddings with the text embeddings.
    for b in range(tokens.shape[0]):
      unbatched_image_encoded = image_encoded[b]
      # Gather one feature per token position. Negative image indices wrap
      # around here, but those rows are dropped by index_to_copy below.
      image_features = unbatched_image_encoded[
          image_indices[b], image_feat_indices[b]
      ]
      index_to_copy = torch.where(image_indices[b] >= 0)[0]
      input_embeds[b] = torch.index_copy(
          input_embeds[b], 0, index_to_copy, image_features[index_to_copy]
      )
    return self.decoder(
        tokens=None,
        input_pos=input_pos,
        kv_cache=kv_cache,
        input_embeds=input_embeds,
        image_indices=image_indices,
        export_config=export_config,
    )
|
152
|
+
|
153
|
+
def get_fake_model_config(**kwargs) -> Gemma3MMConfig:
  """Builds a Gemma3MMConfig from the fake encoder and decoder configs.

  Args:
    **kwargs: Forwarded unchanged to ``decoder.get_fake_decoder_config_4b``.

  Returns:
    A Gemma3MMConfig pairing the fake image-encoder config with the fake 4B
    decoder config, using a layer-norm multimodal normalization, 32 extra
    tokens and a bias-free image projection.
  """
  encoder_cfg = image_encoder.get_fake_image_encoder_config()
  decoder_cfg = decoder.get_fake_decoder_config_4b(**kwargs)
  norm_cfg = cfg.NormalizationConfig(
      type=cfg.NormalizationType.LAYER_NORM,
      epsilon=1e-6,
      enable_hlfb=True,
  )
  return Gemma3MMConfig(
      image_encoder_config=encoder_cfg,
      decoder_config=decoder_cfg,
      mm_norm_config=norm_cfg,
      mm_extra_tokens=32,
      image_token_id=127,
      image_projection_scale=128**0.5,
      image_projection_use_bias=False,
  )
|
167
|
+
|
168
|
+
def build_model_1b(checkpoint_path: str, **kwargs) -> decoder.Decoder:
  """Builds the Gemma3 1B decoder and switches it to eval mode.

  Args:
    checkpoint_path: Checkpoint location. When falsy (e.g. ``None`` or an
      empty string), no checkpoint is loaded.
    **kwargs: Forwarded to the decoder builder/config function.

  Returns:
    The decoder model, in eval mode.
  """
  if not checkpoint_path:
    # No checkpoint: the parameters are whatever decoder.Decoder initializes.
    # TODO: Load the parameters of decoder from checkpoint.
    model = decoder.Decoder(decoder.get_decoder_config_1b(**kwargs))
  else:
    model = decoder.build_model_1b(checkpoint_path, **kwargs)
  model.eval()
  return model
|
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
|
|
122
122
|
hidden_dim, config.normalization_config
|
123
123
|
)
|
124
124
|
self.attention = SelfAttention(
|
125
|
-
config.attention_batch_size,
|
126
125
|
hidden_dim,
|
127
126
|
config.attention_config,
|
128
127
|
enable_hlfb=config.enable_hlfb,
|
@@ -178,7 +177,6 @@ class CrossAttentionBlock2D(nn.Module):
|
|
178
177
|
config.output_dim, config.normalization_config
|
179
178
|
)
|
180
179
|
self.attention = CrossAttention(
|
181
|
-
config.attention_batch_size,
|
182
180
|
config.query_dim,
|
183
181
|
config.cross_dim,
|
184
182
|
config.hidden_dim,
|
@@ -305,7 +303,8 @@ class TransformerBlock2D(nn.Module):
|
|
305
303
|
Args:
|
306
304
|
config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
|
307
305
|
block.
|
308
|
-
dim_override: in case specified, overrides
|
306
|
+
dim_override: in case specified, overrides
|
307
|
+
config.attention_block_config.hidden_dim. Set to None by default.
|
309
308
|
"""
|
310
309
|
super().__init__()
|
311
310
|
self.config = config
|
@@ -82,9 +82,6 @@ _torch_library.ODML_TORCH_LIB.define(
|
|
82
82
|
|
83
83
|
mark_tensor_op = torch.ops.odml_torch.mark_tensor.default
|
84
84
|
|
85
|
-
# Prevent composite inputs and outputs from being DCE'd during torch.export.
|
86
|
-
torch.fx.node.has_side_effect(mark_tensor_op)
|
87
|
-
|
88
85
|
|
89
86
|
@torch.library.impl(
|
90
87
|
_torch_library.ODML_TORCH_LIB, "mark_tensor", "CompositeExplicitAutograd"
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.4.0.dev20250311
|
3
|
+
Version: 0.4.0.dev20250312
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=PlBamGX8JQKmFS2RJl0lWF-mxslb0_eASGwwnezOHuY,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -61,6 +61,15 @@ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=BOQ4zYKMXEX8Adly9-Yt6FB
|
|
61
61
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
62
62
|
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
63
63
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
64
|
+
ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
65
|
+
ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=xAjMqhNrSv2srrBvrwCsnbLzdQXVpkZEOYImb3Mvw3w,3910
|
66
|
+
ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=_7s_JrzwW4rX07f41VDuRLDZDJDshc3vqhXVY92K8q8,15423
|
67
|
+
ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=n2EQVp5SrnMeb0csHrz46_gdNiHTpsApaRmcAc8xyj8,6482
|
68
|
+
ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py,sha256=P11xO0F1MUbLMs8ySz6tu6qGDOOyK43q-HV_pqdsCUY,670
|
69
|
+
ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py,sha256=4Ym4f8pvHu7dUSkTXfSToNuX8X3fhV5kKuhgEzOcyuw,3012
|
70
|
+
ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py,sha256=fB2oYR08u7GcrWYjNbeADRZM5z1vTbE03mHXi497RRw,16140
|
71
|
+
ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py,sha256=NeMqW67uQEQl09R7nE3vSpT84KXmAHEg9oy4-7TVC5k,8104
|
72
|
+
ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
|
64
73
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
65
74
|
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=tMSsqg7LU3LR-PHtKvlWtLCqlk71mfcO9hANU4vnvDM,2734
|
66
75
|
ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
|
@@ -155,7 +164,7 @@ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=0H-Rqtm6ArMxchHS
|
|
155
164
|
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
|
156
165
|
ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
|
157
166
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
158
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
167
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
159
168
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
160
169
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
|
161
170
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -205,7 +214,7 @@ ai_edge_torch/odml_torch/export.py,sha256=7l8R0DEq_vfns8iWpruMlIyaIKZAFzoAy369-7
|
|
205
214
|
ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
|
206
215
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
207
216
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
208
|
-
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=
|
217
|
+
ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
|
209
218
|
ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=h6DQkYVS4fnKMALIVdU6Q7J6Ehg3hMCV4C406SyIk3k,3513
|
210
219
|
ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
|
211
220
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=6Ns2rlfOilLJEk5cUxlkRwm2uxOgEF2-0S2DMcOqr6A,3319
|
@@ -233,8 +242,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
233
242
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
234
243
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
235
244
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
236
|
-
ai_edge_torch_nightly-0.4.0.
|
237
|
-
ai_edge_torch_nightly-0.4.0.
|
238
|
-
ai_edge_torch_nightly-0.4.0.
|
239
|
-
ai_edge_torch_nightly-0.4.0.
|
240
|
-
ai_edge_torch_nightly-0.4.0.
|
245
|
+
ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
246
|
+
ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/METADATA,sha256=DTFxRv9pdU_Uy4wUIU8th5ZpgLzkYVFmM81SgjJAzAo,1966
|
247
|
+
ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
248
|
+
ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
249
|
+
ai_edge_torch_nightly-0.4.0.dev20250312.dist-info/RECORD,,
|
File without changes
|
File without changes
|