ai-edge-torch-nightly 0.5.0.dev20250518__py3-none-any.whl → 0.5.0.dev20250520__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,37 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Einsum layer implementation.
16
+
17
+ from typing import Sequence
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
class Einsum(nn.Module):
  """Einsum layer wrapping over torch.einsum.

  The layer owns a single weight parameter `w` and computes
  `torch.einsum(einsum_str, x, w)` in `forward`.
  """

  def __init__(self, shape: Sequence[int], einsum_str: str):
    """Initializes the Einsum layer.

    Args:
      shape: Shape of the weight parameter `w`.
      einsum_str: Equation passed to `torch.einsum`. Its first operand
        describes the input `x`; its second operand describes `w`.
    """
    super().__init__()
    self.shape = shape
    self.einsum_str = einsum_str
    # The weight is allocated uninitialized (torch.empty) and frozen
    # (requires_grad=False). Presumably it is filled from a checkpoint after
    # construction -- TODO confirm against the model loaders.
    self.w = nn.Parameter(
        torch.empty(shape, dtype=torch.float32),
        requires_grad=False,
    )

  def einsum_fn(self, x: torch.Tensor) -> torch.Tensor:
    """Applies `self.einsum_str` to `x` and the layer weight `w`.

    This is a method rather than a lambda stored on the instance. A lambda
    attribute would close over the original `self`, so `copy.deepcopy` would
    yield a clone that still reads the original's weight, and the module
    could not be pickled (e.g. by `torch.save`).

    Args:
      x: Input tensor whose dimensions match the first operand of
        `self.einsum_str`.

    Returns:
      The result of `torch.einsum(self.einsum_str, x, self.w)`.
    """
    return torch.einsum(self.einsum_str, x, self.w)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass of the Einsum layer."""
    return self.einsum_fn(x)
@@ -0,0 +1,31 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from ai_edge_torch.generative.layers import einsum
17
+ import torch
18
+ from absl.testing import absltest as googletest
19
+
20
+
21
class EinsumTest(googletest.TestCase):
  """Tests for `einsum.Einsum`."""

  def test_einsum(self):
    """The output shape follows the einsum equation."""
    einsum_layer = einsum.Einsum(shape=(5, 10), einsum_str="btf,fd->btd")
    x = torch.ones((1, 8, 5))
    out = einsum_layer(x)
    self.assertEqual(out.shape, (1, 8, 10))

  def test_einsum_values(self):
    """The output values match an independent matmul reference."""
    einsum_layer = einsum.Einsum(shape=(5, 10), einsum_str="btf,fd->btd")
    # The layer allocates its weight with torch.empty, so assign known values
    # before comparing numerically.
    weight = torch.arange(50, dtype=torch.float32).reshape(5, 10)
    with torch.no_grad():
      einsum_layer.w.copy_(weight)
    x = torch.arange(40, dtype=torch.float32).reshape(1, 8, 5)
    out = einsum_layer(x)
    # "btf,fd->btd" contracts the last input dim with the first weight dim,
    # which is exactly a batched matmul.
    torch.testing.assert_close(out, torch.matmul(x, weight))


if __name__ == "__main__":
  googletest.main()
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.5.0.dev20250518"
16
+ __version__ = "0.5.0.dev20250520"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.5.0.dev20250518
3
+ Version: 0.5.0.dev20250520
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=wxjSFq_rBSxSqbUE8E8EJTCkgvgaRLjq_ZuAM-IZpCU,5606
5
- ai_edge_torch/version.py,sha256=ROs2nnrPNKrl8jrGTynAgRfV8IOrNNSZIEuR176ILB8,706
5
+ ai_edge_torch/version.py,sha256=4yV1q9jK9Zr0i0SQM4PpfioywnIovZWfNODUFlxFS-I,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -171,6 +171,8 @@ ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAe
171
171
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
172
172
  ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
173
173
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
174
+ ai_edge_torch/generative/layers/einsum.py,sha256=EsZSWNVWUs0-1plp4TBnhP4ZhaRDBa2VlDO6hWpUAqU,1288
175
+ ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
174
176
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
175
177
  ai_edge_torch/generative/layers/feed_forward_test.py,sha256=8ZGy79BBpsyS6yKKDEKrDt249G5Mz-8VKWW7_WHx0u4,1655
176
178
  ai_edge_torch/generative/layers/kv_cache.py,sha256=b-7shzDaKexmvQF7P3SiAmIz4ZofjYWv3m5u71GojsA,10460
@@ -262,8 +264,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
262
264
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
263
265
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
264
266
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
265
- ai_edge_torch_nightly-0.5.0.dev20250518.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
266
- ai_edge_torch_nightly-0.5.0.dev20250518.dist-info/METADATA,sha256=G9SZNm9HEGhIlSjLENxNqaA7cIFNWJ83ZN8BMZF9igA,2074
267
- ai_edge_torch_nightly-0.5.0.dev20250518.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
268
- ai_edge_torch_nightly-0.5.0.dev20250518.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
269
- ai_edge_torch_nightly-0.5.0.dev20250518.dist-info/RECORD,,
267
+ ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
268
+ ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/METADATA,sha256=P4YDKSZOCPj-hx7bnU6EWLvigx-dpgHd_cIARQd4Fss,2074
269
+ ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
270
+ ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
271
+ ai_edge_torch_nightly-0.5.0.dev20250520.dist-info/RECORD,,