ai-edge-torch-nightly 0.3.0.dev20240817__py3-none-any.whl → 0.3.0.dev20240822__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/test/test_convert_multisig.py +3 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +1 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +6 -4
- ai_edge_torch/generative/utilities/loader.py +3 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240822.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240822.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240822.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240822.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240822.dist-info}/top_level.txt +0 -0
|
@@ -84,6 +84,9 @@ class TestConvertMultiSignature(googletest.TestCase):
|
|
|
84
84
|
)
|
|
85
85
|
)
|
|
86
86
|
|
|
87
|
+
@googletest.skip(
|
|
88
|
+
reason="Re-enable once the tflite converter issue is fixed.",
|
|
89
|
+
)
|
|
87
90
|
def test_convert_mobilenet_v2_signature_helper(self):
|
|
88
91
|
"""Tests the ai_edge_torch.signature helper function works."""
|
|
89
92
|
torch_module = torchvision.models.mobilenet_v2().eval()
|
|
@@ -203,7 +203,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
|
203
203
|
final_norm_config=norm_config,
|
|
204
204
|
parallel_residual=False,
|
|
205
205
|
lm_head_use_bias=False,
|
|
206
|
-
enable_hlfb=
|
|
206
|
+
enable_hlfb=True,
|
|
207
207
|
final_logit_softcap=30.0,
|
|
208
208
|
)
|
|
209
209
|
return config
|
|
@@ -242,7 +242,6 @@ def define_and_run_2b() -> None:
|
|
|
242
242
|
out = model.forward(tokens, input_pos)
|
|
243
243
|
out_final = out[0, 8, :]
|
|
244
244
|
assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
|
|
245
|
-
print(out)
|
|
246
245
|
|
|
247
246
|
|
|
248
247
|
if __name__ == "__main__":
|
|
@@ -99,14 +99,16 @@ def scaled_dot_product_attention_with_hlfb(
|
|
|
99
99
|
The output tensor of scaled_dot_product_attention.
|
|
100
100
|
"""
|
|
101
101
|
|
|
102
|
-
if softcap is not None:
|
|
103
|
-
raise NotImplementedError("SDPA with HLFB not available with softcap.")
|
|
104
|
-
|
|
105
102
|
if scale is None:
|
|
106
103
|
scale = 1.0 / math.sqrt(head_size)
|
|
107
104
|
|
|
105
|
+
attrs = {"scale": scale}
|
|
106
|
+
|
|
107
|
+
if softcap is not None:
|
|
108
|
+
attrs["logit_cap"] = softcap
|
|
109
|
+
|
|
108
110
|
builder = StableHLOCompositeBuilder(
|
|
109
|
-
name="odml.scaled_dot_product_attention", attr=
|
|
111
|
+
name="odml.scaled_dot_product_attention", attr=attrs
|
|
110
112
|
)
|
|
111
113
|
q, k, v, mask = builder.mark_inputs(q, k, v, mask)
|
|
112
114
|
|
|
@@ -72,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
|
|
|
72
72
|
patterns = []
|
|
73
73
|
if os.path.isdir(full_path):
|
|
74
74
|
patterns.append(os.path.join(full_path, "*.bin"))
|
|
75
|
-
patterns.append(os.path.join(full_path, "
|
|
75
|
+
patterns.append(os.path.join(full_path, "*pt"))
|
|
76
76
|
else:
|
|
77
77
|
patterns.append(full_path)
|
|
78
78
|
for pattern in patterns:
|
|
@@ -149,6 +149,7 @@ class ModelLoader:
|
|
|
149
149
|
enabled.
|
|
150
150
|
"""
|
|
151
151
|
state = self._loader(self._file_name)
|
|
152
|
+
state = state["model_state_dict"] if "model_state_dict" in state else state
|
|
152
153
|
converted_state = dict()
|
|
153
154
|
if self._names.embedding is not None:
|
|
154
155
|
converted_state["tok_embedding.weight"] = state.pop(
|
|
@@ -200,7 +201,7 @@ class ModelLoader:
|
|
|
200
201
|
if glob.glob(os.path.join(self._file_name, "*.safetensors")):
|
|
201
202
|
return load_safetensors
|
|
202
203
|
if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
|
|
203
|
-
os.path.join(self._file_name, "
|
|
204
|
+
os.path.join(self._file_name, "*pt")
|
|
204
205
|
):
|
|
205
206
|
return load_pytorch_statedict
|
|
206
207
|
|
ai_edge_torch/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.3.0.
|
|
3
|
+
Version: 0.3.0.dev20240822
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
4
|
ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
|
|
5
|
-
ai_edge_torch/version.py,sha256=
|
|
5
|
+
ai_edge_torch/version.py,sha256=DjujCBc63P1BCTPwFWV93QP5GPGWMkSuuTWTRIg5YNA,706
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
@@ -28,7 +28,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
29
29
|
ai_edge_torch/_convert/test/test_convert.py,sha256=y0ZRivdglGx217rnacze8N6nd7aafk28NkbBFUSa9DQ,13121
|
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=CBiOqq-m7QT2ggBI1jBl9MkTIT5d0nK1tA0BUga0LGs,7994
|
|
31
|
-
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=
|
|
31
|
+
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=cqXJK1YALC1huw87HJVjrdc9xe0LaahRY8tVm_RsKg4,4817
|
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=jLAmyHw5llT2ff8qA8mem3eVN57e_o5EpBnW72ZtP2I,3026
|
|
33
33
|
ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
|
|
34
34
|
ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
|
|
@@ -53,7 +53,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
|
53
53
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
|
|
54
54
|
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
|
|
55
55
|
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
|
|
56
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
|
56
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=q9Zil66EvRKrSpLVQHxKHu_8NL0HAgY2FbtThoTZVUY,8226
|
|
57
57
|
ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
58
58
|
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
|
|
59
59
|
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
|
|
@@ -94,7 +94,7 @@ ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lf
|
|
|
94
94
|
ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
|
|
95
95
|
ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
|
|
96
96
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
|
97
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
|
97
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
|
98
98
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
99
99
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
|
|
100
100
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
|
@@ -114,7 +114,7 @@ ai_edge_torch/generative/test/test_loader.py,sha256=1ZqAq0HY5uIioumsReOVIsbGBx0W
|
|
|
114
114
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=52ciFy_Qol2Xuym6P6EqdL29oai35LSWGvsUwyEdFTo,8477
|
|
115
115
|
ai_edge_torch/generative/test/test_quantize.py,sha256=3SmJm7Kq98gAneU6IGwwJrJYCVH1qwWR6oUxPfb6qiI,5346
|
|
116
116
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
117
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
|
117
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
|
|
118
118
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
|
|
119
119
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
|
|
120
120
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
|
@@ -137,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
|
137
137
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
138
138
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
139
139
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
140
|
-
ai_edge_torch_nightly-0.3.0.
|
|
141
|
-
ai_edge_torch_nightly-0.3.0.
|
|
142
|
-
ai_edge_torch_nightly-0.3.0.
|
|
143
|
-
ai_edge_torch_nightly-0.3.0.
|
|
144
|
-
ai_edge_torch_nightly-0.3.0.
|
|
140
|
+
ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
141
|
+
ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/METADATA,sha256=U8torMOW4U2TfZG7j8mwQ3egNgda43VRAtleRwZmyKw,1885
|
|
142
|
+
ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
143
|
+
ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
144
|
+
ai_edge_torch_nightly-0.3.0.dev20240822.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|