ai-edge-torch-nightly 0.3.0.dev20241219-py3-none-any.whl → 0.3.0.dev20241221-py3-none-any.whl

This diff shows the contents of two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -19,6 +19,7 @@ from typing import Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -61,8 +62,12 @@ class Decoder(model_builder.DecoderOnlyModel):
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
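Note: the rewritten decoder code above builds the rotary embedding for the requested positions on the fly via rotary_pos_emb.build_rope instead of indexing a precomputed self.rope_cache. The sketch below illustrates the general pattern only; build_rope_like and its exact formula are illustrative assumptions, not the library's build_rope implementation.

# Hedged sketch of on-the-fly RoPE cos/sin computation for a batch of positions.
# `build_rope_like` is a hypothetical stand-in for
# ai_edge_torch.generative.layers.rotary_position_embedding.build_rope.
import torch

def build_rope_like(input_pos, n_elem, head_dim, base=10_000):
  # One frequency per pair of rotated channels (n_elem <= head_dim).
  exponents = (2.0 / n_elem) * torch.arange(n_elem // 2, dtype=torch.float32)
  timescale = float(base) ** exponents
  # Angle for every (position, frequency) pair.
  radians = input_pos.float()[:, None] / timescale[None, :]
  return torch.cos(radians), torch.sin(radians)

positions = torch.arange(4) + 1  # PaliGemma positions are 1-based.
cos, sin = build_rope_like(positions, n_elem=64, head_dim=64)
print(cos.shape, sin.shape)  # torch.Size([4, 32]) torch.Size([4, 32])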
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -107,8 +107,6 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,6 +115,9 @@ class DecoderOnlyModel(nn.Module):
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
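The two mask lines that moved are slicing a precomputed causal-mask cache down to the rows for the positions being decoded in this step. A minimal sketch of that indexing, assuming the cache is an additive causal mask of shape (1, 1, kv_cache_max, kv_cache_max); the mask construction here is an assumption, not the library's exact cache builder.

# Hedged sketch of the mask_cache slicing shown in the hunk above.
import torch

kv_cache_max = 8
mask_cache = torch.full((1, 1, kv_cache_max, kv_cache_max), float("-inf"))
mask_cache = torch.triu(mask_cache, diagonal=1)  # 0 on/below diagonal, -inf above

input_pos = torch.tensor([3, 4])              # token positions in this step
mask = mask_cache.index_select(2, input_pos)  # pick the rows for those positions
mask = mask[:, :, :, :kv_cache_max]           # keep only attendable KV slots
print(mask.shape)  # torch.Size([1, 1, 2, 8])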
ai_edge_torch/odml_torch/export.py CHANGED
@@ -198,7 +198,12 @@ class MlirLowered:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )
ai_edge_torch/odml_torch/jax_bridge/__init__.py CHANGED
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -18,6 +18,7 @@ from . import _convolution
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
+from . import _rand
 from . import context
 from . import registry
 from . import utils
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -26,6 +26,7 @@ import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
 
 LoweringContext = context.LoweringContext
 
+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))
@@ -184,8 +185,6 @@ lower_by_torch_xla2(torch.ops.aten.permute_copy)
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)
ai_edge_torch/odml_torch/lowerings/_rand.py ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import uuid
+
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+import torch.utils._pytree as pytree
+
+LoweringContext = context.LoweringContext
+lower = registry.lower
+
+
+def _random_lowering(
+    lctx: LoweringContext,
+    size: list[int],
+    generator,
+    dtype: torch.dtype,
+    rand_tensor,
+    composite_name: str,
+):
+  if dtype is None:
+    dtype = torch.float32
+
+  rand_tensor = rand_tensor.type(dtype)
+  data = rand_tensor.detach().numpy()
+
+  shape, _ = pytree.tree_flatten(size)
+  elty = export_utils.torch_dtype_to_ir_element_type(dtype)
+
+  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"
+
+  with ir.InsertionPoint(lctx.ir_module.body):
+
+    @func.FuncOp.from_py_func(
+        ir.RankedTensorType.get(
+            [len(shape)],
+            ir.IntegerType.get_signless(32),
+        ),
+        name=decomp_name,
+    )
+    def _rand_impl(_):
+      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]
+
+  seed, seed2 = (
+      torch.randint(
+          torch.iinfo(torch.int64).min,
+          torch.iinfo(torch.int64).max,
+          (2,),
+          dtype=torch.int64,
+          generator=generator,
+      )
+      .detach()
+      .numpy()
+  )
+
+  shape_ = stablehlo.constant(
+      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
+  )
+  return stablehlo.CompositeOp(
+      result=[ir.RankedTensorType.get(shape, elty)],
+      inputs=[shape_],
+      name=composite_name,
+      composite_attributes=ir.DictAttr.get({
+          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
+          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
+      }),
+      decomposition=decomp_name,
+  ).results[0]
+
+
+# Schema:
+#   - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#       Device? device=None, bool? pin_memory=None) -> Tensor
+#   - aten::rand.generator(SymInt[] size, *, Generator? generator,
+#       ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.rand)
+def _aten_rand(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.rand.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_uniform",
+  )
+
+
+# Schema:
+#   - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#       Device? device=None, bool? pin_memory=None) -> Tensor
+#   - aten::randn.generator(SymInt[] size, *, Generator? generator,
+#       ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.randn)
+def _aten_randn(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.randn.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_standard_normal",
+  )
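For reference, the new _rand lowering can be exercised through the public converter. The sketch below is hedged: ai_edge_torch.convert and the exported model's export method are the library's documented entry points, while the expectation that the random op surfaces as an "odml.random_uniform" composite is inferred from the composite_name above, not from an inspected graph.

# Hedged sketch: convert a module whose forward calls torch.rand. With the
# _rand lowering registered, aten.rand is expected to be emitted as an
# "odml.random_uniform" StableHLO composite rather than a torch_xla2 lowering.
import ai_edge_torch
import torch


class NoisyIdentity(torch.nn.Module):

  def forward(self, x):
    return x + 0.1 * torch.rand(x.shape)


sample_inputs = (torch.zeros(1, 4),)
edge_model = ai_edge_torch.convert(NoisyIdentity().eval(), sample_inputs)
edge_model.export("/tmp/noisy_identity.tflite")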
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241219"
+__version__ = "0.3.0.dev20241221"
ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241219
+Version: 0.3.0.dev20241221
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=KLgci0sDiJ2ROCyX7x_9Pkz6EzBHZgmqKHPkXReKe3s,706
+ai_edge_torch/version.py,sha256=4pSrONNJgkt6DeTfleRz5DpcHts3SW-iInT2ibr1t9A,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -64,7 +64,7 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=HDDTd4F0kOurhXyqikP5umdY0gVm-FHA1ysaKcz88CM,5261
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
@@ -147,7 +147,7 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=q82-1E2zYlzpbFW6Vw-MWrJivRXHKpRh8jUxpR-w0sY,6349
+ai_edge_torch/generative/utilities/model_builder.py,sha256=plKHp5csnZpx3GQ1SYTqFpdoaxTVcwXgCmzO5N6ya6I,6350
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -169,7 +169,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
+ai_edge_torch/odml_torch/export.py,sha256=QzOPmcNPB7R-KhhPEP0oGVbDRgGPptIxRSoz3S8py9I,13405
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -178,16 +178,17 @@ ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_g
 ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu2YRyGlMZZqVPWUH6g,762
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
-ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
+ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GWFl7WWgExLXu6FEYxnig5_g6hd_Sfnl8690uFg2-CU,1013
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
+ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -200,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241219.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241219.dist-info/METADATA,sha256=3JaZOrMZxk4vVOzoc95KMcXpr3pwvpxIhXdg-_ooijk,1966
-ai_edge_torch_nightly-0.3.0.dev20241219.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241219.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241219.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/METADATA,sha256=_mQiElLiIpig6KWylK15amdyQP57haDyWH4Xaqqt_Ls,1966
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/RECORD,,