ai-edge-torch-nightly 0.3.0.dev20240930__py3-none-any.whl → 0.3.0.dev20241002__py3-none-any.whl
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +12 -5
- ai_edge_torch/generative/test/test_model_conversion_large.py +1 -3
- ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py
CHANGED
@@ -180,9 +180,13 @@ def run_tflite_pipeline(
 
   # Text embedding.
   cond_tokens = model.tokenizer.encode(prompt)
-  cond_context = model.clip(
+  cond_context = model.clip(
+      np.array(cond_tokens).astype(np.int32), signature_name='encode'
+  )
   uncond_tokens = model.tokenizer.encode(uncond_prompt)
-  uncond_context = model.clip(
+  uncond_context = model.clip(
+      np.array(uncond_tokens).astype(np.int32), signature_name='encode'
+  )
   context = np.concatenate([cond_context, uncond_context], axis=0)
   noise_shape = (1, 4, height // 8, width // 8)
 
@@ -198,7 +202,7 @@ def run_tflite_pipeline(
     input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
     input_image_np = util.move_channel(input_image_np, to='first')
     encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
-    latents = model.encoder(input_image_np, encoder_noise)
+    latents = model.encoder(input_image_np.astype(np.float32), encoder_noise)
     latents_noise = np.random.normal(size=noise_shape).astype(np.float32)
     sampler.set_strength(strength=strength)
     latents += latents_noise * sampler.initial_scale
@@ -214,7 +218,10 @@ def run_tflite_pipeline(
     input_latents = latents * sampler.get_input_scale()
     input_latents = input_latents.repeat(2, axis=0)
     output = model.diffusion(
-        input_latents,
+        input_latents.astype(np.float32),
+        context.astype(np.float32),
+        time_embedding,
+        signature_name='diffusion',
     )
     output_cond, output_uncond = np.split(output, 2, axis=0)
     output = cfg_scale * (output_cond - output_uncond) + output_uncond
@@ -222,7 +229,7 @@ def run_tflite_pipeline(
     latents = sampler.step(latents, output)
 
   # Image decoding.
-  images = model.decoder(latents, signature_name='decode')
+  images = model.decoder(latents.astype(np.float32), signature_name='decode')
   images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
   images = util.move_channel(images, to='last')
   if not os.path.exists(output_path):
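The pipeline changes above follow one pattern: every call into a TFLite model now names the exported signature it targets (signature_name='encode', 'diffusion', 'decode') and casts its NumPy inputs to the dtype that signature expects (int32 token ids, float32 tensors), because the TFLite runtime checks input dtypes strictly rather than converting them. A minimal sketch of the same pattern with a plain tf.lite.Interpreter is below; the model path, signature key, and input tensor name are placeholders for illustration, not the names used by this pipeline:

import numpy as np
import tensorflow as tf

# Hypothetical model path and input name, for illustration only.
interpreter = tf.lite.Interpreter(model_path='clip.tflite')
encode = interpreter.get_signature_runner('encode')  # select the exported signature by name

token_ids = [101, 2023, 2003, 102]  # example token ids from a tokenizer
# Cast explicitly: the runner rejects int64 input when the signature expects int32.
outputs = encode(tokens=np.array(token_ids).astype(np.int32))
# outputs is a dict keyed by the signature's output tensor names.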
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -131,9 +131,7 @@ class TestModelConversion(googletest.TestCase):
   def test_phi3(self):
     config = phi3.get_fake_model_config()
     pytorch_model = phi3.Phi3_5Mini(config).eval()
-    self._test_model(
-        config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
-    )
+    self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
ai_edge_torch/lowertools/torch_xla_utils.py
CHANGED
@@ -250,6 +250,7 @@ def merged_bundle_to_tfl_model(
       },
   )
   # Clean up intermediate memory early.
+  del tf_functions
   del tf_module
   del tf_concrete_funcs
   gc.collect()
@@ -271,6 +272,8 @@ def merged_bundle_to_tfl_model(
   conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
+  del converter
+  gc.collect()
 
   if (
       quant_config is not None
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240930
+Version: 0.3.0.dev20241002
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240930.dist-info → ai_edge_torch_nightly-0.3.0.dev20241002.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=ODx8CRsxZZYlliSx6vnHxxTorI9c0WPgrVvwGY5KAQI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -78,7 +78,7 @@ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=v
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
-ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=
+ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=GnY3vPZ-obrWuJifuE5bUooKLqAI7v6q71oaTuLKeBE,8778
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
 ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=XIXIB0vCvQKOGyIyiZeiIA5DLeSXjkudywvJS4FK7AM,2431
 ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -125,7 +125,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=ASXTeO9TxjhqcNwXwbyMUP07aqye7wD6JU6OGZCEmR4,8907
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -147,7 +147,7 @@ ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuT
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
 ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
@@ -181,8 +181,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/METADATA,sha256=l2x0NhvSM0VtobvX6i8hXWKYdfjaRUizk42xaJrQXtw,1897
+ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241002.dist-info/RECORD,,