ai-edge-torch-nightly 0.3.0.dev20240823__py3-none-any.whl → 0.3.0.dev20240828__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.

Potentially problematic release.

This version of ai-edge-torch-nightly might be problematic.

ai_edge_torch/_convert/test/test_convert_multisig.py CHANGED
@@ -17,11 +17,47 @@
 import ai_edge_torch
 from ai_edge_torch.testing import model_coverage
 import torch
-import torchvision
+from torch import nn

 from tensorflow.python.platform import googletest


+class FullyConnectedModel(nn.Module):
+  """A simple fully connected model with two fully connected layers."""
+
+  def __init__(self, input_size, hidden_size, output_size):
+    super(FullyConnectedModel, self).__init__()
+    self.fc = nn.Linear(input_size, hidden_size)  # Fully connected layer
+    self.relu = nn.ReLU()  # Activation function
+    self.output = nn.Linear(hidden_size, output_size)
+
+  def forward(self, x):
+    x = self.fc(x)
+    x = self.relu(x)
+    x = self.output(x)
+    return x
+
+
+class FullyConvModel(nn.Module):
+  """A simple fully convolutional model with two convolutions."""
+
+  def __init__(self):
+    super(FullyConvModel, self).__init__()
+    self.conv1 = nn.Conv2d(
+        3, 16, kernel_size=3, padding=1
+    )  # Input channels: 3 (RGB), Output channels: 16
+    self.relu = nn.ReLU(inplace=True)
+    self.conv2 = nn.Conv2d(
+        16, 1, kernel_size=1
+    )  # Output channels: 1 (single channel output)
+
+  def forward(self, x):
+    x = self.conv1(x)
+    x = self.relu(x)
+    x = self.conv2(x)
+    return x
+
+
 class TestConvertMultiSignature(googletest.TestCase):
   """Tests conversion of various modules through multi-signature conversion."""

@@ -29,12 +65,12 @@ class TestConvertMultiSignature(googletest.TestCase):
     super().setUp()
     torch.manual_seed(0)

-  def test_convert_mobilenet_v2_with_default(self):
+  def test_convert_with_default(self):
     """Tests conversion of a model with two signatures one of which is the default."""
-    torch_module = torchvision.models.mobilenet_v2().eval()
+    torch_module = FullyConvModel().eval()

-    args = (torch.randn(4, 3, 224, 224),)
-    large_args = (torch.randn(4, 3, 336, 336),)
+    args = (torch.randn(4, 3, 12, 12),)
+    large_args = (torch.randn(4, 3, 24, 24),)

     signature_name = "large_input"

@@ -51,12 +87,12 @@ class TestConvertMultiSignature(googletest.TestCase):
         )
     )

-  def test_convert_mobilenet_v2_no_default(self):
+  def test_convert_no_default(self):
     """Tests conversion of a model with two signatures none of which is the default."""
-    torch_module = torchvision.models.mobilenet_v2().eval()
+    torch_module = FullyConvModel().eval()

-    args = (torch.randn(4, 3, 224, 224),)
-    large_args = (torch.randn(4, 3, 336, 336),)
+    args = (torch.randn(4, 3, 12, 12),)
+    large_args = (torch.randn(4, 3, 24, 24),)

     signature_name_1 = "input"
     signature_name_2 = "large_input"
@@ -84,12 +120,12 @@ class TestConvertMultiSignature(googletest.TestCase):
         )
     )

-  def test_convert_mobilenet_v2_signature_helper(self):
+  def test_convert_signature_helper(self):
     """Tests the ai_edge_torch.signature helper function works."""
-    torch_module = torchvision.models.mobilenet_v2().eval()
+    torch_module = FullyConvModel().eval()

-    args = (torch.randn(4, 3, 224, 224),)
-    large_args = (torch.randn(4, 3, 336, 336),)
+    args = (torch.randn(4, 3, 12, 12),)
+    large_args = (torch.randn(4, 3, 24, 24),)

     signature_name = "large_input"

@@ -108,39 +144,43 @@ class TestConvertMultiSignature(googletest.TestCase):

   def test_convert_separate_modules(self):
     """Tests conversion of two completely different modules as separate signatures."""
-    mobilentv2 = torchvision.models.mobilenet_v2().eval()
-    resnet18 = torchvision.models.resnet18().eval()
+    fully_conv = FullyConvModel().eval()
+    fully_connected = FullyConnectedModel(10, 5, 10).eval()

-    mobilenet_args = (torch.randn(4, 3, 224, 224),)
-    resnet_args = (torch.randn(4, 3, 224, 224),)
+    fully_conv_args = (torch.randn(4, 3, 12, 12),)
+    fully_connected_args = (torch.randn(10),)

-    mobilenet_signature_name = "mobilentv2"
-    resnet_signature_name = "resnet18"
+    fully_conv_signature_name = "fully_conv"
+    fully_connected_signature_name = "fully_connected"

     edge_model = (
         ai_edge_torch.signature(
-            mobilenet_signature_name, mobilentv2, mobilenet_args
+            fully_conv_signature_name, fully_conv, fully_conv_args
+        )
+        .signature(
+            fully_connected_signature_name,
+            fully_connected,
+            fully_connected_args,
         )
-        .signature(resnet_signature_name, resnet18, resnet_args)
-        .convert(resnet18, resnet_args)
+        .convert(fully_connected, fully_connected_args)
     )

-    mobilenet_inference_args = (torch.randn(4, 3, 224, 224),)
-    resnet_inference_args = (torch.randn(4, 3, 224, 224),)
+    fully_conv_inference_args = (torch.randn(4, 3, 12, 12),)
+    fully_connected_inference_args = (torch.randn(10),)
     self.assertTrue(
         model_coverage.compare_tflite_torch(
             edge_model,
-            mobilentv2,
-            mobilenet_inference_args,
-            signature_name=mobilenet_signature_name,
+            fully_conv,
+            fully_conv_inference_args,
+            signature_name=fully_conv_signature_name,
         )
     )
     self.assertTrue(
         model_coverage.compare_tflite_torch(
             edge_model,
-            resnet18,
-            resnet_inference_args,
-            signature_name=resnet_signature_name,
+            fully_connected,
+            fully_connected_inference_args,
+            signature_name=fully_connected_signature_name,
         )
     )

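For context, these tests exercise ai_edge_torch's multi-signature builder: signature() calls are chained and finished with convert(), which can also register a default signature. A minimal sketch of the pattern, assuming the same API the diff shows (the stand-in modules and the output path are illustrative only, not taken from this package):

    import ai_edge_torch
    import torch
    from torch import nn

    # Two tiny stand-in modules, one per signature.
    encoder = nn.Linear(8, 4).eval()
    decoder = nn.Linear(4, 8).eval()

    edge_model = (
        ai_edge_torch.signature('encode', encoder, (torch.randn(1, 8),))
        .signature('decode', decoder, (torch.randn(1, 4),))
        .convert(encoder, (torch.randn(1, 8),))  # also sets the default signature
    )
    edge_model.export('/tmp/multi_signature.tflite')  # hypothetical output path
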
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py CHANGED
@@ -15,16 +15,16 @@

 import argparse
 import os
-from pathlib import Path
-from typing import Dict, Optional
+import pathlib
+from typing import Optional

-import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
-from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer  # NOQA
-import ai_edge_torch.generative.examples.stable_diffusion.util as util
-from ai_edge_torch.model import TfLiteModel
+import ai_edge_torch
+from ai_edge_torch.generative.examples.stable_diffusion import samplers
+from ai_edge_torch.generative.examples.stable_diffusion import tokenizer
+from ai_edge_torch.generative.examples.stable_diffusion import util
 import numpy as np
 from PIL import Image
-from tqdm import tqdm
+import tqdm

 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
@@ -104,12 +104,12 @@ class StableDiffusion:
       diffusion_ckpt: str,
       decoder_ckpt: str
   ):
-    self.tokenizer = Tokenizer(tokenizer_vocab_dir)
-    self.clip = TfLiteModel.load(clip_ckpt)
-    self.decoder = TfLiteModel.load(decoder_ckpt)
-    self.diffusion = TfLiteModel.load(diffusion_ckpt)
+    self.tokenizer = tokenizer.Tokenizer(tokenizer_vocab_dir)
+    self.clip = ai_edge_torch.model.TfLiteModel.load(clip_ckpt)
+    self.decoder = ai_edge_torch.model.TfLiteModel.load(decoder_ckpt)
+    self.diffusion = ai_edge_torch.model.TfLiteModel.load(diffusion_ckpt)
     if encoder_ckpt is not None:
-      self.encoder = TfLiteModel.load(encoder_ckpt)
+      self.encoder = ai_edge_torch.model.TfLiteModel.load(encoder_ckpt)


 def run_tflite_pipeline(
@@ -128,48 +128,32 @@ def run_tflite_pipeline(
 ):
   """Run stable diffusion pipeline with tflite model.

-  model:
-
-  StableDiffsuion model.
-  prompt:
-  The prompt to guide the image generation.
-  output_path:
-  The path to the generated output image.
-  uncond_prompt:
-  The prompt not to guide the image generation.
-  cfg_scale:
-  Guidance scale of classifier-free guidance. Higher guidance scale encourages
-  to generate
-  images that are closely linked to the text `prompt`, usually at the expense
-  of lower
-  image quality.
-  height:
-  The height in pixels of the generated image.
-  width:
-  The width in pixels of the generated image.
-  sampler:
-  A sampler to be used to denoise the encoded image latents. Can be one of
-  `k_lms, `k_euler`,
-  or `k_euler_ancestral`.
-  n_inference_steps:
-  The number of denoising steps. More denoising steps usually lead to a higher
-  quality image at the
-  expense of slower inference. This parameter will be modulated by `strength`.
-  seed:
-  A seed to make generation deterministic.
-  strength:
-  Conceptually, indicates how much to transform the reference `input_image`.
-  Must be between 0 and 1.
-  `input_image` will be used as a starting point, adding more noise to it the
-  larger the `strength`.
-  The number of denoising steps depends on the amount of noise initially
-  added. When `strength` is 1,
-  added noise will be maximum and the denoising process will run for the full
-  number of iterations
-  specified in `n_inference_steps`. A value of 1, therefore, essentially
-  ignores `input_image`.
-  input_image:
-  Image which is served as the starting point for the image generation.
+  Args:
+    model: StableDiffsuion model.
+    prompt: The prompt to guide the image generation.
+    output_path: The path to the generated output image.
+    uncond_prompt: The prompt not to guide the image generation.
+    cfg_scale: Guidance scale of classifier-free guidance. Higher guidance scale
+      encourages to generate images that are closely linked to the text
+      `prompt`, usually at the expense of lower image quality.
+    height: The height in pixels of the generated image.
+    width: The width in pixels of the generated image.
+    sampler: A sampler to be used to denoise the encoded image latents. Can be
+      one of `k_lms, `k_euler`, or `k_euler_ancestral`.
+    n_inference_steps: The number of denoising steps. More denoising steps
+      usually lead to a higher quality image at the expense of slower inference.
+      This parameter will be modulated by `strength`.
+    seed: A seed to make generation deterministic.
+    strength: Conceptually, indicates how much to transform the reference
+      `input_image`. Must be between 0 and 1. `input_image` will be used as a
+      starting point, adding more noise to it the larger the `strength`. The
+      number of denoising steps depends on the amount of noise initially added.
+      When `strength` is 1, added noise will be maximum and the denoising
+      process will run for the full number of iterations specified in
+      `n_inference_steps`. A value of 1, therefore, essentially ignores
+      `input_image`.
+    input_image: Image which is served as the starting point for the image
+      generation.
   """
   if not 0 < strength < 1:
     raise ValueError('strength must be between 0 and 1')
@@ -202,7 +186,8 @@ def run_tflite_pipeline(
   context = np.concatenate([cond_context, uncond_context], axis=0)
   noise_shape = (1, 4, height // 8, width // 8)

-  # Initialization starts from input_image if any, otherwise, starts from a random sampling.
+  # Initialization starts from input_image if any, otherwise, starts from a
+  # random sampling.
   if input_image:
     if not hasattr(model, 'encoder'):
       raise AttributeError(
@@ -210,7 +195,6 @@ def run_tflite_pipeline(
           ' input_image.'
       )
     input_image = input_image.resize((width, height))
-    input_image_np = np.array(input_image).astype(np.float32)
     input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
     input_image_np = util.move_channel(input_image_np, to='first')
     encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
@@ -223,8 +207,8 @@ def run_tflite_pipeline(
   latents *= sampler.initial_scale

   # Diffusion process.
-  timesteps = tqdm(sampler.timesteps)
-  for i, timestep in enumerate(timesteps):
+  timesteps = tqdm.tqdm(sampler.timesteps)
+  for _, timestep in enumerate(timesteps):
     time_embedding = util.get_time_embedding(timestep)

     input_latents = latents * sampler.get_input_scale()
@@ -242,7 +226,7 @@ def run_tflite_pipeline(
   images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
   images = util.move_channel(images, to='last')
   if not os.path.exists(output_path):
-    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
   Image.fromarray(images[0].astype(np.uint8)).save(output_path)

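For orientation, the pieces changed above compose into a text-to-image pipeline: StableDiffusion loads the tokenizer plus the TFLite CLIP, diffusion, and decoder submodels, and run_tflite_pipeline drives tokenization, the denoising loop, and decoding. A hedged usage sketch follows; the checkpoint paths are placeholders, and treating encoder_ckpt and the generation knobs as optional keyword arguments is an assumption about the signatures, not confirmed by this diff:

    from ai_edge_torch.generative.examples.stable_diffusion import pipeline

    # All paths are hypothetical; each .tflite file is a converted submodel.
    model = pipeline.StableDiffusion(
        tokenizer_vocab_dir='/tmp/sd/tokenizer',
        clip_ckpt='/tmp/sd/clip.tflite',
        encoder_ckpt=None,  # only needed when seeding generation with input_image
        diffusion_ckpt='/tmp/sd/diffusion.tflite',
        decoder_ckpt='/tmp/sd/decoder.tflite',
    )
    pipeline.run_tflite_pipeline(
        model,
        prompt='a photograph of an astronaut riding a horse',
        output_path='/tmp/sd/output.png',
        n_inference_steps=20,  # more steps: higher quality, slower inference
        seed=100,
    )
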
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20240823"
+__version__ = "0.3.0.dev20240828"
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/METADATA → ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240823
+Version: 0.3.0.dev20240828
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -20,7 +20,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development
 Classifier: Topic :: Software Development :: Libraries
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
-Requires-Python: >=3.9, <3.12
+Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy
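
The practical effect of the Requires-Python change: pip will no longer refuse to install this nightly on Python 3.12 and newer, though actual compatibility still depends on the package's own dependencies. A specific build can be pinned with, for example, pip install ai-edge-torch-nightly==0.3.0.dev20240828.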
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/RECORD → ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=7tox6sdFIlCYPLDYpjFcD8cPTSivURCL_VV6-Dt5Sfc,4910
-ai_edge_torch/version.py,sha256=a63GrjqX4sRjk0WbC_0gGhT-ax_TLEi4iCEw0Iys7bw,706
+ai_edge_torch/version.py,sha256=FYYHMZbGfkHhNzKfScFWouShgQ9DOVNXZ7WWrrsKjPY,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -28,7 +28,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/test/test_convert.py,sha256=tvj7fWHHmA9ddtcu-Fp3lJ6emaAQMrtK9wCG0cjgRAo,14413
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=CBiOqq-m7QT2ggBI1jBl9MkTIT5d0nK1tA0BUga0LGs,7994
-ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=0xIkahEU26Qx9GGn6Dm05ObIqJvsCdh692dREcaHEdE,4725
+ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=ShZeakqVN9jg9mgFvvMWP0BoPF-u4BTM2hoEkDmLwj0,5780
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=jLAmyHw5llT2ff8qA8mem3eVN57e_o5EpBnW72ZtP2I,3026
 ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
 ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
@@ -64,7 +64,7 @@ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
-ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=Wc94X_kEnbInTAXFgf-VuCvv1A0HxxWrFZ7Tsq3Z8n0,8662
+ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
 ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=XIXIB0vCvQKOGyIyiZeiIA5DLeSXjkudywvJS4FK7AM,2431
 ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -137,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/METADATA,sha256=OPYTq5RQCL2lvIFeBKOwbFusi4rq_Qo2ytgn_JQTVb0,1885
-ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/METADATA,sha256=Q5exnNzCT4z8p2mh5MFKl7MgfezYvW0ebW9k4PW8jJo,1878
+ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240828.dist-info/RECORD,,