ai-edge-torch-nightly 0.3.0.dev20241220__py3-none-any.whl → 0.3.0.dev20241221__py3-none-any.whl

ai_edge_torch/generative/examples/paligemma/decoder.py CHANGED
@@ -19,6 +19,7 @@ from typing import Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -61,8 +62,12 @@ class Decoder(model_builder.DecoderOnlyModel):
     assert input_embeds is not None
 
     repo_pos = input_pos + 1  # PaliGemma position is 1-based.
-    cos, sin = self.rope_cache
-    rope = (cos.index_select(0, repo_pos), sin.index_select(0, repo_pos))
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        repo_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
 
     # The first part of input_embeds are image embeddings. Diagonal causal mask
     # doesn't work here.
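
Note: the decoder change above stops indexing a precomputed self.rope_cache and instead builds the rotary embeddings for the current (1-based) positions on the fly via rotary_pos_emb.build_rope. The sketch below shows the general shape of such a computation; it is illustrative only (the function name, argument layout, and the exact cos/sin tensor layout returned by the library's build_rope are assumptions, not its implementation):

    import torch

    def build_rope_sketch(positions: torch.Tensor, n_elem: int, head_dim: int, base: float = 10_000.0):
      # Standard RoPE frequencies for the rotated slice (n_elem of head_dim dims).
      freq = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
      angles = positions.float().unsqueeze(-1) * freq  # (seq_len, n_elem // 2)
      return torch.cos(angles), torch.sin(angles)

    # e.g. positions 1..4 (PaliGemma positions are 1-based, as in the diff above)
    cos, sin = build_rope_sketch(torch.arange(1, 5), n_elem=256, head_dim=256)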
ai_edge_torch/generative/utilities/model_builder.py CHANGED
@@ -107,8 +107,6 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
 
     # ROPE parameters for all attn_configs are the same. Take the first one.
     attn_config = self.config.block_config(0).attn_config
@@ -117,6 +115,9 @@ class DecoderOnlyModel(nn.Module):
         input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
     )
 
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
     return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
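
Note: the model_builder change above only reorders the forward pass so the attention-mask slice happens after the RoPE parameters are built; the slicing itself is unchanged. A minimal sketch of what that slice does, with illustrative shapes and names:

    import torch

    # Assume a causal mask cache of shape (1, 1, max_seq_len, kv_cache_max);
    # the sizes here are illustrative.
    kv_cache_max = max_seq_len = 8
    mask_cache = torch.full((1, 1, max_seq_len, kv_cache_max), float("-inf")).triu(1)
    input_pos = torch.tensor([3])                  # decode step at position 3
    mask = mask_cache.index_select(2, input_pos)   # (1, 1, 1, kv_cache_max)
    mask = mask[:, :, :, :kv_cache_max]            # keep only the valid KV columns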
ai_edge_torch/odml_torch/export.py CHANGED
@@ -198,7 +198,12 @@ class MlirLowered:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )
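
Note: with a newer StableHLO API (get_api_version() >= 9) the module is serialized against a 4-week forward-compatibility window instead of the minimum supported version. The same guard, repackaged as a standalone helper for clarity; the helper name is illustrative, and `stablehlo` is assumed to be the MLIR StableHLO dialect module bundled with jaxlib:

    def pick_target_stablehlo_version(stablehlo):
      if stablehlo.get_api_version() < 9:
        # Older API: fall back to the minimum supported StableHLO version.
        return stablehlo.get_minimum_version()
      # Newer API: serialize for a 4-week compatibility window.
      return stablehlo.get_version_from_compatibility_requirement(
          stablehlo.StablehloCompatibilityRequirement.WEEK_4
      )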
ai_edge_torch/odml_torch/jax_bridge/__init__.py CHANGED
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap
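
Note: the re-export above keeps jax_bridge.wrap available at the package level while also importing jax_bridge.utils eagerly. A trivial check of the two import paths, as implied by the diff:

    from ai_edge_torch.odml_torch import jax_bridge
    from ai_edge_torch.odml_torch.jax_bridge import utils

    print(jax_bridge.wrap)  # still exposed on the package
    print(utils)            # now importable alongside it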
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -18,6 +18,7 @@ from . import _convolution
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
+from . import _rand
 from . import context
 from . import registry
 from . import utils
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -26,6 +26,7 @@ import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
 
 LoweringContext = context.LoweringContext
 
+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))
@@ -184,8 +185,6 @@ lower_by_torch_xla2(torch.ops.aten.permute_copy)
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)
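
Note: aten.rand and aten.randn are dropped from the generic torch_xla2 bridge above because they now have dedicated lowerings, registered in the new _rand.py module added below.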
ai_edge_torch/odml_torch/lowerings/_rand.py ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import uuid
+
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+import torch.utils._pytree as pytree
+
+LoweringContext = context.LoweringContext
+lower = registry.lower
+
+
+def _random_lowering(
+    lctx: LoweringContext,
+    size: list[int],
+    generator,
+    dtype: torch.dtype,
+    rand_tensor,
+    composite_name: str,
+):
+  if dtype is None:
+    dtype = torch.float32
+
+  rand_tensor = rand_tensor.type(dtype)
+  data = rand_tensor.detach().numpy()
+
+  shape, _ = pytree.tree_flatten(size)
+  elty = export_utils.torch_dtype_to_ir_element_type(dtype)
+
+  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"
+
+  with ir.InsertionPoint(lctx.ir_module.body):
+
+    @func.FuncOp.from_py_func(
+        ir.RankedTensorType.get(
+            [len(shape)],
+            ir.IntegerType.get_signless(32),
+        ),
+        name=decomp_name,
+    )
+    def _rand_impl(_):
+      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]
+
+  seed, seed2 = (
+      torch.randint(
+          torch.iinfo(torch.int64).min,
+          torch.iinfo(torch.int64).max,
+          (2,),
+          dtype=torch.int64,
+          generator=generator,
+      )
+      .detach()
+      .numpy()
+  )
+
+  shape_ = stablehlo.constant(
+      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
+  )
+  return stablehlo.CompositeOp(
+      result=[ir.RankedTensorType.get(shape, elty)],
+      inputs=[shape_],
+      name=composite_name,
+      composite_attributes=ir.DictAttr.get({
+          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
+          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
+      }),
+      decomposition=decomp_name,
+  ).results[0]
+
+
+# Schema:
+#   - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#       Device? device=None, bool? pin_memory=None) -> Tensor
+#   - aten::rand.generator(SymInt[] size, *, Generator? generator,
+#       ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.rand)
+def _aten_rand(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.rand.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_uniform",
+  )
+
+
+# Schema:
+#   - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#       Device? device=None, bool? pin_memory=None) -> Tensor
+#   - aten::randn.generator(SymInt[] size, *, Generator? generator,
+#       ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#       bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.randn)
+def _aten_randn(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.randn.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_standard_normal",
+  )
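
Note: the new lowering above bakes a host-generated random tensor into a StableHLO composite (odml.random_uniform / odml.random_standard_normal) whose decomposition returns that constant, recording the generator seeds as composite attributes so a runtime can substitute a real RNG op. A minimal usage sketch of a model whose torch.rand call would now hit these lowerings; it assumes the public ai_edge_torch.convert entry point with the odml_torch lowering backend enabled, and the module name is illustrative:

    import torch
    import ai_edge_torch

    class AddNoise(torch.nn.Module):
      def forward(self, x):
        # aten.rand here would be lowered to an odml.random_uniform composite.
        return x + torch.rand(x.shape)

    model = AddNoise().eval()
    sample_args = (torch.zeros(1, 4),)
    edge_model = ai_edge_torch.convert(model, sample_args)
    edge_model.export("/tmp/add_noise.tflite")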
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241220"
+__version__ = "0.3.0.dev20241221"
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/METADATA → ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241220
+Version: 0.3.0.dev20241221
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/RECORD → ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=xD-MWAEa1ROHhyF3rY7MaL28xsuON0aJwaiXbJ04qfc,706
+ai_edge_torch/version.py,sha256=4pSrONNJgkt6DeTfleRz5DpcHts3SW-iInT2ibr1t9A,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=SzbR16V2JEfkCjjPwRVAFUbFnzu-_1iHPKgGT9Yz7gQ,5678
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -64,7 +64,7 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2u
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
-ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=eICKQkJsJuEUkuvn5ymUsI9CGB-oNbgV7VH7BlmklfQ,4961
+ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=HDDTd4F0kOurhXyqikP5umdY0gVm-FHA1ysaKcz88CM,5261
 ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
 ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=nDyI-wUFJSawu57uLbFENei5l4cciqZ8lM5S5beN0FU,5604
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
@@ -147,7 +147,7 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=q82-1E2zYlzpbFW6Vw-MWrJivRXHKpRh8jUxpR-w0sY,6349
+ai_edge_torch/generative/utilities/model_builder.py,sha256=plKHp5csnZpx3GQ1SYTqFpdoaxTVcwXgCmzO5N6ya6I,6350
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
@@ -169,7 +169,7 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=Wc_JM7U2IjZeBmXA6t1AZxREGOWjZ6EB-PIhEevWWeU,13207
+ai_edge_torch/odml_torch/export.py,sha256=QzOPmcNPB7R-KhhPEP0oGVbDRgGPptIxRSoz3S8py9I,13405
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -178,16 +178,17 @@ ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_g
 ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu2YRyGlMZZqVPWUH6g,762
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
-ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
+ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GWFl7WWgExLXu6FEYxnig5_g6hd_Sfnl8690uFg2-CU,1013
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
+ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -200,8 +201,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/METADATA,sha256=PfyYhqbf7VEibw2TEDRb8tBOIPG9dfXhT9tNNou_iZg,1966
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241220.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/METADATA,sha256=_mQiElLiIpig6KWylK15amdyQP57haDyWH4Xaqqt_Ls,1966
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241221.dist-info/RECORD,,