ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl → 0.2.0.dev20240601__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +5 -2
- ai_edge_torch/convert/test/test_convert_composites.py +3 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +79 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +107 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +113 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +499 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +67 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/utilities/loader.py +8 -4
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/RECORD +24 -8
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.1.dev202405131930.dist-info → ai_edge_torch_nightly-0.2.0.dev20240601.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
# Testing weight loader utilities.
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import tempfile
|
|
19
|
+
import unittest
|
|
20
|
+
|
|
21
|
+
import safetensors.torch
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
|
25
|
+
from ai_edge_torch.generative.utilities import loader as loading_utils
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TestLoader(unittest.TestCase):
|
|
29
|
+
"""Unit tests that check weight loader."""
|
|
30
|
+
|
|
31
|
+
def test_load_safetensors(self):
|
|
32
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
33
|
+
file_path = os.path.join(temp_dir, "test.safetensors")
|
|
34
|
+
test_data = {"weight": torch.randn(20, 10), "bias": torch.randn(20)}
|
|
35
|
+
safetensors.torch.save_file(test_data, file_path)
|
|
36
|
+
|
|
37
|
+
loaded_tensors = loading_utils.load_safetensors(file_path)
|
|
38
|
+
self.assertIn("weight", loaded_tensors)
|
|
39
|
+
self.assertIn("bias", loaded_tensors)
|
|
40
|
+
|
|
41
|
+
def test_load_statedict(self):
|
|
42
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
43
|
+
file_path = os.path.join(temp_dir, "test.pt")
|
|
44
|
+
model = torch.nn.Linear(10, 5)
|
|
45
|
+
state_dict = model.state_dict()
|
|
46
|
+
torch.save(state_dict, file_path)
|
|
47
|
+
|
|
48
|
+
loaded_tensors = loading_utils.load_pytorch_statedict(file_path)
|
|
49
|
+
self.assertIn("weight", loaded_tensors)
|
|
50
|
+
self.assertIn("bias", loaded_tensors)
|
|
51
|
+
|
|
52
|
+
def test_model_loader(self):
|
|
53
|
+
with tempfile.TemporaryDirectory() as temp_dir:
|
|
54
|
+
file_path = os.path.join(temp_dir, "test.safetensors")
|
|
55
|
+
test_weights = {
|
|
56
|
+
"lm_head.weight": torch.randn((32000, 2048)),
|
|
57
|
+
"model.embed_tokens.weight": torch.randn((32000, 2048)),
|
|
58
|
+
"model.layers.0.input_layernorm.weight": torch.randn((2048,)),
|
|
59
|
+
"model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
|
|
60
|
+
"model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
|
|
61
|
+
"model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
|
|
62
|
+
"model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
|
|
63
|
+
"model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
|
|
64
|
+
"model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
|
|
65
|
+
"model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),
|
|
66
|
+
"model.layers.0.self_attn.v_proj.weight": torch.randn((256, 2048)),
|
|
67
|
+
"model.norm.weight": torch.randn((2048,)),
|
|
68
|
+
}
|
|
69
|
+
safetensors.torch.save_file(test_weights, file_path)
|
|
70
|
+
cfg = tiny_llama.get_model_config()
|
|
71
|
+
cfg.num_layers = 1
|
|
72
|
+
model = tiny_llama.TinyLLamma(cfg)
|
|
73
|
+
|
|
74
|
+
loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
|
|
75
|
+
# If load() returns successfully, it means all the tensors were initialized.
|
|
76
|
+
loader.load(model, strict=True)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
if __name__ == "__main__":
|
|
80
|
+
unittest.main()
|
|
@@ -228,14 +228,14 @@ class ModelLoader:
|
|
|
228
228
|
q_name = self._names.attn_query_proj.format(idx)
|
|
229
229
|
k_name = self._names.attn_key_proj.format(idx)
|
|
230
230
|
v_name = self._names.attn_value_proj.format(idx)
|
|
231
|
-
converted_state[f"{prefix}.atten_func.
|
|
231
|
+
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
|
|
232
232
|
config,
|
|
233
233
|
state.pop(f"{q_name}.weight"),
|
|
234
234
|
state.pop(f"{k_name}.weight"),
|
|
235
235
|
state.pop(f"{v_name}.weight"),
|
|
236
236
|
)
|
|
237
237
|
if config.attn_config.qkv_use_bias:
|
|
238
|
-
converted_state[f"{prefix}.atten_func.
|
|
238
|
+
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
|
|
239
239
|
config,
|
|
240
240
|
state.pop(f"{q_name}.bias"),
|
|
241
241
|
state.pop(f"{k_name}.bias"),
|
|
@@ -243,9 +243,13 @@ class ModelLoader:
|
|
|
243
243
|
)
|
|
244
244
|
|
|
245
245
|
o_name = self._names.attn_output_proj.format(idx)
|
|
246
|
-
converted_state[f"{prefix}.atten_func.
|
|
246
|
+
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
|
|
247
|
+
f"{o_name}.weight"
|
|
248
|
+
)
|
|
247
249
|
if config.attn_config.output_proj_use_bias:
|
|
248
|
-
converted_state[f"{prefix}.atten_func.
|
|
250
|
+
converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
|
|
251
|
+
f"{o_name}.bias"
|
|
252
|
+
)
|
|
249
253
|
|
|
250
254
|
def _map_norm(
|
|
251
255
|
self,
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.2.0.dev20240601
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -24,8 +24,8 @@ Requires-Python: >=3.9, <3.12
|
|
|
24
24
|
Description-Content-Type: text/markdown
|
|
25
25
|
License-File: LICENSE
|
|
26
26
|
Requires-Dist: numpy
|
|
27
|
-
Requires-Dist: safetensors
|
|
28
27
|
Requires-Dist: scipy
|
|
28
|
+
Requires-Dist: safetensors
|
|
29
29
|
Requires-Dist: tabulate
|
|
30
30
|
Requires-Dist: torch ==2.4.*
|
|
31
31
|
|
|
@@ -6,7 +6,7 @@ ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBK
|
|
|
6
6
|
ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
|
|
7
7
|
ai_edge_torch/convert/fx_passes/__init__.py,sha256=Ll2nNwufjcV5nSruQPXiloq7F1E7pWJ2T5clXmy1lk8,2825
|
|
8
8
|
ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
|
|
9
|
-
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=
|
|
9
|
+
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=wHVWNNMu5h_ya6GnnJn0cNif9xmdSqr8Vm-R7lllxZM,6213
|
|
10
10
|
ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py,sha256=76XYoIlFDgrzp5QemoaEalPFcEbfszkEH_PLvO1ASCk,2607
|
|
11
11
|
ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
|
|
12
12
|
ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
|
|
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
|
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
|
|
23
23
|
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
24
24
|
ai_edge_torch/convert/test/test_convert.py,sha256=USduDO6PaO3nlA82jMihTct--mCU_ugILZDin00lcJ8,8092
|
|
25
|
-
ai_edge_torch/convert/test/test_convert_composites.py,sha256=
|
|
25
|
+
ai_edge_torch/convert/test/test_convert_composites.py,sha256=SrVn_cEMtQhYYCMOUKK0K7M57MQNQX-lOUwieln0HGA,6616
|
|
26
26
|
ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
|
|
27
27
|
ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
|
|
28
28
|
ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
|
|
@@ -38,6 +38,21 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YF4Ua-1lnL3qhQnh1sY5-HlY
|
|
|
38
38
|
ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
39
39
|
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
|
|
40
40
|
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
|
|
41
|
+
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
42
|
+
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
|
|
43
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=KR1Ci4rlJeeGfsFRliCxUve9K7RTJLZfTRMgFtfQ4MU,2434
|
|
44
|
+
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=6REAYy1Bv-Iv5zcmA_m_W6fH6jt5a3IS6Vge18jS_Wo,3633
|
|
45
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=AgVAdUbSkHXONVUjAyBQEXhIUUlinf9kNljcBpWnj3A,3276
|
|
46
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=nq94VpQ103eOimnmdyg7u3Xk1LH1IxGlmIbr2AttRIk,16224
|
|
47
|
+
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=L6hLaMQGb8-_BwSvTLIuDnZwfTqn0K4swBUjfPnYWZo,2341
|
|
48
|
+
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
|
|
49
|
+
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
|
|
50
|
+
ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=NFpOfA4KN0JpShm5QvuYbQYZ844NzexWD8nV3WjMOZM,2397
|
|
51
|
+
ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
|
|
52
|
+
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py,sha256=w9C2iVFAn4F2SLJiFdjwR9rRPf5wc3OBS1t0GIOEy08,2310
|
|
53
|
+
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py,sha256=24aIPj6AoK_vSPqmpfmYd-IA8-Uvq6wHLwdVS34Pwtc,2513
|
|
54
|
+
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX9ZSaxwSak2KI44j6TEr_g4pdxS3xpka4u0trjbo,2788
|
|
55
|
+
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
|
|
41
56
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
42
57
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
|
|
43
58
|
ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
|
|
@@ -65,10 +80,11 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=9ktL7fT8C5j1dnY_7
|
|
|
65
80
|
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2wrf_epILE_7Hx-XfZQ9buk,1798
|
|
66
81
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
|
|
67
82
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
83
|
+
ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
|
|
68
84
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
|
|
69
85
|
ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
|
|
70
86
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
71
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
|
87
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=c-ZOIDBVnat_5l2W5sWU7HQm7CL-wducS8poSu5PlUg,10107
|
|
72
88
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
|
|
73
89
|
ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
|
|
74
90
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
|
|
@@ -84,8 +100,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
|
|
|
84
100
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
85
101
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
86
102
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
|
|
87
|
-
ai_edge_torch_nightly-0.
|
|
88
|
-
ai_edge_torch_nightly-0.
|
|
89
|
-
ai_edge_torch_nightly-0.
|
|
90
|
-
ai_edge_torch_nightly-0.
|
|
91
|
-
ai_edge_torch_nightly-0.
|
|
103
|
+
ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
104
|
+
ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/METADATA,sha256=36DXHi7B4r-1hwut2pnUCNat4bVeolghRw2KwpCO3i0,1748
|
|
105
|
+
ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
106
|
+
ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
107
|
+
ai_edge_torch_nightly-0.2.0.dev20240601.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|