ai-edge-torch-nightly: ai_edge_torch_nightly-0.1.dev202405131930-py3-none-any.whl → ai_edge_torch_nightly-0.2.0.dev20240601-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly has been flagged as potentially problematic; consult the package's registry page for details (the original link was not preserved in this copy).

Files changed (24):
  1. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +5 -2
  2. ai_edge_torch/convert/test/test_convert_composites.py +3 -0
  3. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  4. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/clip.py +79 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +107 -0
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +113 -0
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +499 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +67 -0
  10. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  11. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  12. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  13. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  14. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  15. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  16. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  17. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  18. ai_edge_torch/generative/test/loader_test.py +80 -0
  19. ai_edge_torch/generative/utilities/loader.py +8 -4
  20. {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/METADATA +2 -2
  21. {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/RECORD +24 -8
  22. {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,80 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Testing weight loader utilities.
16
+
17
+ import os
18
+ import tempfile
19
+ import unittest
20
+
21
+ import safetensors.torch
22
+ import torch
23
+
24
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
25
+ from ai_edge_torch.generative.utilities import loader as loading_utils
26
+
27
+
28
+ class TestLoader(unittest.TestCase):
29
+ """Unit tests that check weight loader."""
30
+
31
+ def test_load_safetensors(self):
32
+ with tempfile.TemporaryDirectory() as temp_dir:
33
+ file_path = os.path.join(temp_dir, "test.safetensors")
34
+ test_data = {"weight": torch.randn(20, 10), "bias": torch.randn(20)}
35
+ safetensors.torch.save_file(test_data, file_path)
36
+
37
+ loaded_tensors = loading_utils.load_safetensors(file_path)
38
+ self.assertIn("weight", loaded_tensors)
39
+ self.assertIn("bias", loaded_tensors)
40
+
41
+ def test_load_statedict(self):
42
+ with tempfile.TemporaryDirectory() as temp_dir:
43
+ file_path = os.path.join(temp_dir, "test.pt")
44
+ model = torch.nn.Linear(10, 5)
45
+ state_dict = model.state_dict()
46
+ torch.save(state_dict, file_path)
47
+
48
+ loaded_tensors = loading_utils.load_pytorch_statedict(file_path)
49
+ self.assertIn("weight", loaded_tensors)
50
+ self.assertIn("bias", loaded_tensors)
51
+
52
+ def test_model_loader(self):
53
+ with tempfile.TemporaryDirectory() as temp_dir:
54
+ file_path = os.path.join(temp_dir, "test.safetensors")
55
+ test_weights = {
56
+ "lm_head.weight": torch.randn((32000, 2048)),
57
+ "model.embed_tokens.weight": torch.randn((32000, 2048)),
58
+ "model.layers.0.input_layernorm.weight": torch.randn((2048,)),
59
+ "model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
60
+ "model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
61
+ "model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
62
+ "model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
63
+ "model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
64
+ "model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
65
+ "model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),
66
+ "model.layers.0.self_attn.v_proj.weight": torch.randn((256, 2048)),
67
+ "model.norm.weight": torch.randn((2048,)),
68
+ }
69
+ safetensors.torch.save_file(test_weights, file_path)
70
+ cfg = tiny_llama.get_model_config()
71
+ cfg.num_layers = 1
72
+ model = tiny_llama.TinyLLamma(cfg)
73
+
74
+ loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
75
+ # if returns successfully, it means all the tensors were initiallized.
76
+ loader.load(model, strict=True)
77
+
78
+
79
+ if __name__ == "__main__":
80
+ unittest.main()
@@ -228,14 +228,14 @@ class ModelLoader:
228
228
  q_name = self._names.attn_query_proj.format(idx)
229
229
  k_name = self._names.attn_key_proj.format(idx)
230
230
  v_name = self._names.attn_value_proj.format(idx)
231
- converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
231
+ converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
232
232
  config,
233
233
  state.pop(f"{q_name}.weight"),
234
234
  state.pop(f"{k_name}.weight"),
235
235
  state.pop(f"{v_name}.weight"),
236
236
  )
237
237
  if config.attn_config.qkv_use_bias:
238
- converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
238
+ converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
239
239
  config,
240
240
  state.pop(f"{q_name}.bias"),
241
241
  state.pop(f"{k_name}.bias"),
@@ -243,9 +243,13 @@ class ModelLoader:
243
243
  )
244
244
 
245
245
  o_name = self._names.attn_output_proj.format(idx)
246
- converted_state[f"{prefix}.atten_func.proj.weight"] = state.pop(f"{o_name}.weight")
246
+ converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
247
+ f"{o_name}.weight"
248
+ )
247
249
  if config.attn_config.output_proj_use_bias:
248
- converted_state[f"{prefix}.atten_func.proj.bias"] = state.pop(f"{o_name}.bias")
250
+ converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
251
+ f"{o_name}.bias"
252
+ )
249
253
 
250
254
  def _map_norm(
251
255
  self,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.1.dev202405131930
3
+ Version: 0.2.0.dev20240601
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -24,8 +24,8 @@ Requires-Python: >=3.9, <3.12
24
24
  Description-Content-Type: text/markdown
25
25
  License-File: LICENSE
26
26
  Requires-Dist: numpy
27
- Requires-Dist: safetensors
28
27
  Requires-Dist: scipy
28
+ Requires-Dist: safetensors
29
29
  Requires-Dist: tabulate
30
30
  Requires-Dist: torch ==2.4.*
31
31
 
@@ -6,7 +6,7 @@ ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBK
6
6
  ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
7
7
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=Ll2nNwufjcV5nSruQPXiloq7F1E7pWJ2T5clXmy1lk8,2825
8
8
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
9
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=quuPsyRtOeumB4SVRYoj2UmSWfrGzJ6Q2ZqjWeG3UPI,6150
9
+ ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=wHVWNNMu5h_ya6GnnJn0cNif9xmdSqr8Vm-R7lllxZM,6213
10
10
  ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py,sha256=76XYoIlFDgrzp5QemoaEalPFcEbfszkEH_PLvO1ASCk,2607
11
11
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
12
12
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
23
23
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
24
24
  ai_edge_torch/convert/test/test_convert.py,sha256=USduDO6PaO3nlA82jMihTct--mCU_ugILZDin00lcJ8,8092
25
- ai_edge_torch/convert/test/test_convert_composites.py,sha256=gFUa_lKNUfeYMgtulqJvRAtWIvzy3f3eXptMBiJDbms,6403
25
+ ai_edge_torch/convert/test/test_convert_composites.py,sha256=SrVn_cEMtQhYYCMOUKK0K7M57MQNQX-lOUwieln0HGA,6616
26
26
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
27
27
  ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
28
28
  ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
@@ -38,6 +38,21 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YF4Ua-1lnL3qhQnh1sY5-HlY
38
38
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
39
39
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
40
40
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
41
+ ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
+ ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
43
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=KR1Ci4rlJeeGfsFRliCxUve9K7RTJLZfTRMgFtfQ4MU,2434
44
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=6REAYy1Bv-Iv5zcmA_m_W6fH6jt5a3IS6Vge18jS_Wo,3633
45
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=AgVAdUbSkHXONVUjAyBQEXhIUUlinf9kNljcBpWnj3A,3276
46
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=nq94VpQ103eOimnmdyg7u3Xk1LH1IxGlmIbr2AttRIk,16224
47
+ ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=L6hLaMQGb8-_BwSvTLIuDnZwfTqn0K4swBUjfPnYWZo,2341
48
+ ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
49
+ ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
50
+ ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=NFpOfA4KN0JpShm5QvuYbQYZ844NzexWD8nV3WjMOZM,2397
51
+ ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
52
+ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py,sha256=w9C2iVFAn4F2SLJiFdjwR9rRPf5wc3OBS1t0GIOEy08,2310
53
+ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py,sha256=24aIPj6AoK_vSPqmpfmYd-IA8-Uvq6wHLwdVS34Pwtc,2513
54
+ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX9ZSaxwSak2KI44j6TEr_g4pdxS3xpka4u0trjbo,2788
55
+ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
41
56
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
57
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
43
58
  ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
@@ -65,10 +80,11 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7
65
80
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
66
81
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
67
82
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
83
+ ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
68
84
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
69
85
  ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
70
86
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
71
- ai_edge_torch/generative/utilities/loader.py,sha256=QrGZ3JlEN_tn8j6EdZOxVt_0u3yB5vBrR3KJtNaAwV8,10029
87
+ ai_edge_torch/generative/utilities/loader.py,sha256=c-ZOIDBVnat_5l2W5sWU7HQm7CL-wducS8poSu5PlUg,10107
72
88
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
73
89
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
74
90
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -84,8 +100,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
84
100
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
85
101
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
86
102
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
87
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
88
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA,sha256=lQcAb0esNisYUqkzDRHamW4S9luvrJ4QU75042IAqWc,1750
89
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
90
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
91
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD,,
103
+ ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
104
+ ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/METADATA,sha256=36DXHi7B4r-1hwut2pnUCNat4bVeolghRw2KwpCO3i0,1748
105
+ ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
106
+ ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
107
+ ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/RECORD,,