ai-edge-torch-nightly 0.3.0.dev20240929__py3-none-any.whl → 0.3.0.dev20241002__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -180,9 +180,13 @@ def run_tflite_pipeline(
180
180
 
181
181
  # Text embedding.
182
182
  cond_tokens = model.tokenizer.encode(prompt)
183
- cond_context = model.clip(np.array(cond_tokens), signature_name='encode')
183
+ cond_context = model.clip(
184
+ np.array(cond_tokens).astype(np.int32), signature_name='encode'
185
+ )
184
186
  uncond_tokens = model.tokenizer.encode(uncond_prompt)
185
- uncond_context = model.clip(np.array(uncond_tokens), signature_name='encode')
187
+ uncond_context = model.clip(
188
+ np.array(uncond_tokens).astype(np.int32), signature_name='encode'
189
+ )
186
190
  context = np.concatenate([cond_context, uncond_context], axis=0)
187
191
  noise_shape = (1, 4, height // 8, width // 8)
188
192
 
@@ -198,7 +202,7 @@ def run_tflite_pipeline(
198
202
  input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
199
203
  input_image_np = util.move_channel(input_image_np, to='first')
200
204
  encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
201
- latents = model.encoder(input_image_np, encoder_noise)
205
+ latents = model.encoder(input_image_np.astype(np.float32), encoder_noise)
202
206
  latents_noise = np.random.normal(size=noise_shape).astype(np.float32)
203
207
  sampler.set_strength(strength=strength)
204
208
  latents += latents_noise * sampler.initial_scale
@@ -214,7 +218,10 @@ def run_tflite_pipeline(
214
218
  input_latents = latents * sampler.get_input_scale()
215
219
  input_latents = input_latents.repeat(2, axis=0)
216
220
  output = model.diffusion(
217
- input_latents, context, time_embedding, signature_name='diffusion'
221
+ input_latents.astype(np.float32),
222
+ context.astype(np.float32),
223
+ time_embedding,
224
+ signature_name='diffusion',
218
225
  )
219
226
  output_cond, output_uncond = np.split(output, 2, axis=0)
220
227
  output = cfg_scale * (output_cond - output_uncond) + output_uncond
@@ -222,7 +229,7 @@ def run_tflite_pipeline(
222
229
  latents = sampler.step(latents, output)
223
230
 
224
231
  # Image decoding.
225
- images = model.decoder(latents, signature_name='decode')
232
+ images = model.decoder(latents.astype(np.float32), signature_name='decode')
226
233
  images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
227
234
  images = util.move_channel(images, to='last')
228
235
  if not os.path.exists(output_path):
@@ -131,9 +131,7 @@ class TestModelConversion(googletest.TestCase):
131
131
  def test_phi3(self):
132
132
  config = phi3.get_fake_model_config()
133
133
  pytorch_model = phi3.Phi3_5Mini(config).eval()
134
- self._test_model(
135
- config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
136
- )
134
+ self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
137
135
 
138
136
  @googletest.skipIf(
139
137
  ai_edge_config.Config.use_torch_xla,
@@ -250,6 +250,7 @@ def merged_bundle_to_tfl_model(
250
250
  },
251
251
  )
252
252
  # Clean up intermediate memory early.
253
+ del tf_functions
253
254
  del tf_module
254
255
  del tf_concrete_funcs
255
256
  gc.collect()
@@ -271,6 +272,8 @@ def merged_bundle_to_tfl_model(
271
272
  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
272
273
 
273
274
  tflite_model = converter.convert()
275
+ del converter
276
+ gc.collect()
274
277
 
275
278
  if (
276
279
  quant_config is not None
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240929"
16
+ __version__ = "0.3.0.dev20241002"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240929
3
+ Version: 0.3.0.dev20241002
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=NLOvzXKYuiZP_6pLbpKF-IAcp8M6nFLhL-QqbfqxnEU,706
6
+ ai_edge_torch/version.py,sha256=ODx8CRsxZZYlliSx6vnHxxTorI9c0WPgrVvwGY5KAQI,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -78,7 +78,7 @@ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=v
78
78
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
79
79
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
80
80
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
81
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
81
+ ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=GnY3vPZ-obrWuJifuE5bUooKLqAI7v6q71oaTuLKeBE,8778
82
82
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
83
83
  ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=XIXIB0vCvQKOGyIyiZeiIA5DLeSXjkudywvJS4FK7AM,2431
84
84
  ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -125,7 +125,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
125
125
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
126
126
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
127
127
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
128
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=SBGHbY8-k7kSEEv-WQQlxGIYtJEVBIbjJPygGdDg9Qg,8921
128
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=ASXTeO9TxjhqcNwXwbyMUP07aqye7wD6JU6OGZCEmR4,8907
129
129
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
130
130
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
131
131
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -147,7 +147,7 @@ ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuT
147
147
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
148
148
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
149
149
  ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
150
- ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
150
+ ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
151
151
  ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
152
152
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
153
153
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
@@ -181,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
181
181
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
182
182
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
183
183
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
184
- ai_edge_torch_nightly-0.3.0.dev20240929.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
- ai_edge_torch_nightly-0.3.0.dev20240929.dist-info/METADATA,sha256=BcRgHI-zv6HtkVf7tOnbFLsh_V3QZlheWv_4Qto2ToU,1897
186
- ai_edge_torch_nightly-0.3.0.dev20240929.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
- ai_edge_torch_nightly-0.3.0.dev20240929.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
- ai_edge_torch_nightly-0.3.0.dev20240929.dist-info/RECORD,,
184
+ ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
185
+ ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/METADATA,sha256=l2x0NhvSM0VtobvX6i8hXWKYdfjaRUizk42xaJrQXtw,1897
186
+ ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
187
+ ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
188
+ ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/RECORD,,