ai-edge-torch-nightly 0.4.0.dev20250327__py3-none-any.whl → 0.4.0.dev20250328__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # This module contains custom torch ops for generative.
@@ -14,9 +14,6 @@
14
14
  # ==============================================================================
15
15
  # Common utility functions for data loading etc.
16
16
  from dataclasses import dataclass
17
- import glob
18
- import os
19
- from typing import Sequence
20
17
  from ai_edge_torch.odml_torch import lowerings
21
18
  from jax._src.lib.mlir import ir
22
19
  from jax._src.lib.mlir.dialects import hlo as stablehlo
@@ -31,8 +28,12 @@ def bmm_4d(
31
28
  ) -> torch.Tensor:
32
29
  if not (lhs.ndim == 4 and rhs.ndim == 4):
33
30
  raise ValueError("bmm_4d requires LHS and RHS have rank 4.")
34
- d0_can_bcast = lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
35
- d1_can_bcast = lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
31
+ d0_can_bcast = (
32
+ lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
33
+ )
34
+ d1_can_bcast = (
35
+ lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
36
+ )
36
37
  if not (d0_can_bcast and d1_can_bcast):
37
38
  raise ValueError("bmm_4d requires that dimensions 0 and 1 can broadcast.")
38
39
 
@@ -12,10 +12,8 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
- # Common utility functions for data loading etc.
15
+ # Dynamic update slice op for KV cache update.
16
16
  from dataclasses import dataclass
17
- import glob
18
- import os
19
17
  from typing import Sequence
20
18
  from ai_edge_torch.odml_torch import lowerings
21
19
  from jax._src.lib.mlir import ir
@@ -23,7 +23,7 @@ import functools
23
23
  from typing import Any, List, Tuple, Type
24
24
  from ai_edge_torch.generative.layers import model_config
25
25
  from ai_edge_torch.generative.layers.experimental import types
26
- from ai_edge_torch.generative.utilities import dynamic_update_slice as dus_utils
26
+ from ai_edge_torch.generative.custom_ops import dynamic_update_slice as dus_utils
27
27
  import torch
28
28
  import torch.utils._pytree as pytree
29
29
 
@@ -18,9 +18,9 @@
18
18
  import math
19
19
  from typing import Optional
20
20
 
21
+ from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
21
22
  from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils
22
23
  from ai_edge_torch.generative.layers.experimental import types
23
- from ai_edge_torch.generative.utilities import bmm_4d as bmm_lib
24
24
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
25
25
  from multipledispatch import dispatch
26
26
  import torch
@@ -18,8 +18,8 @@
18
18
  import dataclasses
19
19
  from typing import List, Tuple
20
20
 
21
+ from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
21
22
  from ai_edge_torch.generative.layers import model_config
22
- from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
23
23
  import torch
24
24
  import torch.utils._pytree as pytree
25
25
 
@@ -15,8 +15,6 @@
15
15
 
16
16
  """A suite of tests to validate the Dynamic Update Slice Custom Op."""
17
17
 
18
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
19
- import ai_edge_torch.generative.layers.model_config as cfg
20
18
  import torch
21
19
  from torch import nn
22
20
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.4.0.dev20250327"
16
+ __version__ = "0.4.0.dev20250328"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.4.0.dev20250327
3
+ Version: 0.4.0.dev20250328
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
5
- ai_edge_torch/version.py,sha256=_xd77UxPOmGu3uf6thVadog8RZtE7pOm_1fg7TyXnQ8,706
5
+ ai_edge_torch/version.py,sha256=IyhMWqN-g3wNhaYTXhegaL93NTmZkKXsXS6yx4E2kko,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -44,6 +44,9 @@ ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmT
44
44
  ai_edge_torch/fx_infra/graph_utils.py,sha256=nqGe-xIJ77RamSUh0UYyI2XHOsZqFDWax-vpRAtVR_E,2796
45
45
  ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
46
46
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
+ ai_edge_torch/generative/custom_ops/__init__.py,sha256=la5qAQVbq1Iz2K6-YFI-2BRm292SlJ0xgciHMXqF9Wg,727
48
+ ai_edge_torch/generative/custom_ops/bmm_4d.py,sha256=JmVbZCujG_wuBchma8QF3DSBfVca52xYwMV7vAzKOII,2507
49
+ ai_edge_torch/generative/custom_ops/dynamic_update_slice.py,sha256=ZGAq2CfWZsfef5mHulsWmyUx0dDWJX6J6xPjhBrjQdM,2097
47
50
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
48
51
  ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
49
52
  ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
@@ -148,7 +151,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDq
148
151
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
149
152
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
150
153
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
151
- ai_edge_torch/generative/layers/kv_cache.py,sha256=jwbt0-2fd_CNWS2fp4nf0zvh6kk5citINGlFC_RtEUU,6540
154
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=zjdovWqgEKtx7cvbA0apOwXaNft5AXxNTbJhBT4CXyg,6541
152
155
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
153
156
  ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
154
157
  ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
@@ -156,8 +159,8 @@ ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIr
156
159
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
157
160
  ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
158
161
  ai_edge_torch/generative/layers/experimental/attention.py,sha256=95djjlJItDVuSNE3BL0b6u3lQoIhmmdvaik7qBBvQA0,8909
159
- ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=0H-Rqtm6ArMxchHSv3eeX8W3AryoF73EFEpGNfjciK8,9996
160
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
162
+ ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=uXUxiQjPndXYZVGKgm9FxzHgQDal8GdY7cUZDpc_Sno,9997
163
+ ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFW0iGcZjTuej6VFIkwdSY28fIQi_KTAVdT8gWNmq7o,2880
161
164
  ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
162
165
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
163
166
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
@@ -171,7 +174,7 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
171
174
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
172
175
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
173
176
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
174
- ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
177
+ ai_edge_torch/generative/test/test_custom_dus.py,sha256=ifgnUCWihT59eFdLrlc5_j9sWygEKclU6Iqw6zdlgeI,3177
175
178
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=MBPS-0bDXB0tQSKHa1XwDQeVIfabRbc8JQA99h9fzlQ,5961
176
179
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
177
180
  ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
@@ -180,9 +183,7 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
180
183
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
181
184
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
182
185
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
183
- ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
184
186
  ai_edge_torch/generative/utilities/converter.py,sha256=VtG42CVz657XbvTj-FZJiCFW0Hm11OVKKC_mr2tjxhc,8413
185
- ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
186
187
  ai_edge_torch/generative/utilities/loader.py,sha256=KmbjlKpSJEYaYCy5gxOhiaFj6aVAniaBl-kALv_qsGs,13546
187
188
  ai_edge_torch/generative/utilities/model_builder.py,sha256=eY3qAcBhupIn955YnWuzUi9hoWYvl4ntRWA6PBudzMo,6888
188
189
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -239,8 +240,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
239
240
  ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
240
241
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
241
242
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
242
- ai_edge_torch_nightly-0.4.0.dev20250327.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
243
- ai_edge_torch_nightly-0.4.0.dev20250327.dist-info/METADATA,sha256=ndWLikDMYDqNVK_ga3v1vm0trmEDLphlGnUYr9gU8W0,1966
244
- ai_edge_torch_nightly-0.4.0.dev20250327.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
245
- ai_edge_torch_nightly-0.4.0.dev20250327.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
246
- ai_edge_torch_nightly-0.4.0.dev20250327.dist-info/RECORD,,
243
+ ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
244
+ ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/METADATA,sha256=j22JbcB95xcu4aqu-G4NRbh1NxwLi9GpPjJjvpLsaSE,1966
245
+ ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
246
+ ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
247
+ ai_edge_torch_nightly-0.4.0.dev20250328.dist-info/RECORD,,