ai-edge-torch-nightly 0.3.0.dev20241120__py3-none-any.whl → 0.3.0.dev20241122__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,7 +39,6 @@ def _func_to_torch_module(func: Callable[..., torch.Tensor]):
39
39
  return TestModule(func).eval()
40
40
 
41
41
 
42
- @googletest.skip('Temporary outage due to changes for b/377531086')
43
42
  class TestConvertComposites(googletest.TestCase):
44
43
  """Tests conversion modules that are meant to be wrapped as composites."""
45
44
 
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  # Implementation for Rotary Position embedding. https://arxiv.org/pdf/2104.09864.pdf
16
+ from typing import Tuple
16
17
  import torch
17
18
 
18
19
 
@@ -36,3 +37,52 @@ def apply_rope(
36
37
  rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
37
38
  roped = (x * cos) + (rotated * sin)
38
39
  return roped.transpose(1, 2).type_as(x)
40
+
41
+
42
def apply_rope_inline(
    q: torch.Tensor,
    k: torch.Tensor,
    input_pos: torch.Tensor,
    n_elem: int,
    base: int = 10_000,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Computes rotary positional embedding inline for a query and key.

  The rotation angles are derived on the fly from `input_pos` rather than read
  from a precomputed cos/sin table.

  Args:
    q: the query tensor. Dims 1 and 2 are swapped internally and dim 1 must
      line up with `input_pos` (presumably (batch, seq, heads, head_dim) --
      TODO confirm against callers). The last dim must be even.
    k: the key tensor, laid out like `q`.
    input_pos: the sequence indices for the query and key.
    n_elem: number of elements of the head dimension for RoPE computation. If
      it is not positive, `q` and `k` are returned unchanged.
    base: base of the geometric progression of rotation frequencies.

  Returns:
    A tuple `(q_roped, k_roped)` with the same layout and dtype as `q` / `k`.
  """
  if n_elem <= 0:
    return q, k

  # Pair i of the last dim is rotated by input_pos / base ** (2 * i / n_elem),
  # for each of the q.shape[-1] // 2 pairs (x1[i], x2[i]).
  freq_exponents = (2.0 / n_elem) * torch.arange(
      q.shape[-1] // 2, dtype=torch.float32
  )
  timescale = float(base) ** freq_exponents
  # radians has shape (1, len(input_pos), q.shape[-1] // 2).
  radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
      0
  ).unsqueeze(0)
  cos = torch.cos(radians).type_as(q)
  sin = torch.sin(radians).type_as(q)

  def apply(x, sin, cos):
    # Swap dims 1 and 2 so the position axis sits at dim -2, where cos/sin
    # broadcast against it; swap back afterwards.
    x = x.transpose(1, 2)
    x1, x2 = torch.split(x, x.shape[-1] // 2, dim=-1)
    left = x1 * cos - x2 * sin
    right = x2 * cos + x1 * sin
    return torch.cat([left, right], dim=-1).transpose(1, 2)

  return apply(q, sin, cos), apply(k, sin, cos)
@@ -0,0 +1,107 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """A suite of tests to validate the Dynamic Update Slice Custom Op."""
17
+
18
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
import torch
from torch import nn

from absl.testing import absltest as googletest, parameterized
24
+
25
+
26
def updated_slice_matches(buffer, update, index):
  """Returns whether `buffer` holds `update` at the offsets given by `index`.

  Args:
    buffer: the tensor to inspect.
    update: the tensor expected to be found inside `buffer`.
    index: per-dimension start offsets (ints or scalar integer tensors), one
      per dimension of `update`.

  Returns:
    True if the window of `buffer` starting at `index` with the shape of
    `update` exists in full and is element-wise close to `update`.
  """
  # Use a tuple of slices: indexing with a *list* of slices depends on a
  # deprecated sequence-as-tuple heuristic.
  indexer = tuple(
      slice(int(start), int(start) + size)
      for start, size in zip(index, update.shape)
  )
  window = buffer[indexer]
  # A window clipped at the buffer edge is shorter than `update`; without this
  # check allclose could broadcast it and report a false match.
  if window.shape != update.shape:
    return False
  return torch.allclose(window, update)
30
+
31
+
32
def intT(x):
  """Builds a new int32 tensor from `x` (used for DUS start indices)."""
  as_tensor = torch.tensor(x)
  return as_tensor.to(dtype=torch.int32)
34
+
35
+
36
class DUSMod(nn.Module):
  """Small module around the dynamic_update_slice custom op for export tests.

  The op's result is multiplied by 2, presumably so the exported graph has an
  ordinary aten op consuming the custom op's output.
  """

  def forward(self, buffer, update, index):
    # `dynamic_update_slice` must be imported at module level for this call.
    out = dynamic_update_slice(buffer, update, index)
    return out * 2
42
+
43
+
44
@googletest.skip('Enable this when odml_torch is default b/373387583')
class TestCustomDUS(parameterized.TestCase):
  """Checks the dynamic_update_slice custom op eagerly and under export."""

  @parameterized.named_parameters(
      (
          'DUS_whole_buffer',
          torch.randn(1, 1280, 4, 64),
          torch.randn([1, 1024, 4, 64]),
          [intT(0), intT(0), intT(0), intT(0)],
      ),
      (
          'DUS_kv_example',
          torch.randn(2, 1280, 4, 64),
          torch.randn([2, 1024, 4, 64]),
          [intT(0), intT(0), intT(0), intT(0)],
      ),
      (
          'DUS_3d',
          torch.randn(2, 256, 4, 64),
          torch.randn([2, 256, 2, 64]),
          [intT(0), intT(0), intT(2), intT(0)],
      ),
      (
          'DUS_3d_v2',
          torch.randn(2, 256, 4, 64),
          torch.randn([2, 256, 3, 64]),
          [intT(0), intT(0), intT(1), intT(0)],
      ),
      (
          'DUS_3d_v3',
          torch.randn(6, 8, 32),
          torch.randn([6, 3, 32]),
          [intT(0), intT(5), intT(0)],
      ),
      (
          'DUS_2d',
          torch.randn(8, 32),
          torch.randn([8, 12]),
          [intT(0), intT(20)],
      ),
  )
  def test_opcheck_dynamic_update_slice(self, buffer, update, indices):
    # opcheck exercises the custom op's registration (schema, fake kernel...).
    torch.library.opcheck(dynamic_update_slice, (buffer, update, indices))
    result = dynamic_update_slice(buffer, update, indices)
    self.assertTrue(updated_slice_matches(result, update, indices))

  def test_exported_program(self):
    example_args = (
        torch.randn(1, 1280, 4, 64),
        torch.randn([1, 1024, 4, 64]),
        [intT(0) for _ in range(4)],
    )
    exported = torch.export.export(DUSMod(), example_args)
    # The custom op must survive export as a call_function node.
    found = any(
        node.op == 'call_function'
        and node.target.__name__.startswith('dynamic_update_slice')
        for node in exported.graph.nodes
    )
    self.assertTrue(found)
104
+
105
+
106
+ if __name__ == '__main__':
107
+ googletest.main()
@@ -0,0 +1,56 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Custom op definition and StableHLO lowering for dynamic_update_slice.
16
+ from dataclasses import dataclass
17
+ import glob
18
+ import os
19
+ from typing import Sequence
20
+ from ai_edge_torch.odml_torch import lowerings
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import torch
24
+
25
+
26
+ # Use torch.library.custom_op to define a new custom operator.
27
+ # TODO: Update impl for multiple non-trivial start_indices
28
+ @torch.library.custom_op("ai_edge_torch::dynamic_update_slice", mutates_args=())
29
+ def dynamic_update_slice(
30
+ in_tensor: torch.Tensor,
31
+ update: torch.Tensor,
32
+ start_indices: Sequence[torch.Tensor],
33
+ ) -> torch.Tensor:
34
+ compare_size = torch.tensor(in_tensor.size()) == torch.tensor(update.size())
35
+ mismatch = torch.nonzero(~compare_size, as_tuple=False)
36
+ dim = mismatch[0].item() if len(mismatch) > 0 else 0
37
+ start = start_indices[dim].item()
38
+ end = start + update.shape[dim]
39
+ indices = torch.arange(start, end).to(torch.long)
40
+ return in_tensor.index_copy(dim, indices, update)
41
+
42
+
43
# Register a ``FakeTensor`` kernel so the operator can be traced and exported.
@dynamic_update_slice.register_fake
def _(in_tensor, update, start_indices):
  # The result is always a fresh tensor shaped and typed like in_tensor, so
  # the other arguments are not needed.
  del update, start_indices
  return in_tensor.detach().clone()
47
+
48
+
49
@lowerings.lower(torch.ops.ai_edge_torch.dynamic_update_slice)
def _dynamic_update_slice_lower(
    lctx,
    in_tensor: ir.Value,
    update: ir.Value,
    start_indices: Sequence[ir.Value],
):
  """Lowers the custom op directly to `stablehlo.dynamic_update_slice`."""
  del lctx  # No lowering context is needed for a 1:1 mapping.
  lowered = stablehlo.dynamic_update_slice(in_tensor, update, start_indices)
  return lowered
@@ -185,6 +185,7 @@ def merged_bundle_to_tfl_model(
185
185
  converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
186
186
  converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
187
187
  converter._experimental_enable_composite_direct_lowering = True
188
+ converter._experimental_enable_dynamic_update_slice = True
188
189
  converter.model_origin_framework = "PYTORCH"
189
190
 
190
191
  conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
@@ -24,6 +24,7 @@ from ai_edge_torch.odml_torch.jax_bridge import utils
24
24
  import jax
25
25
  from jax._src.lib.mlir import ir
26
26
  from jax._src.lib.mlir.dialects import func
27
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
27
28
  import torch.utils._pytree as pytree
28
29
 
29
30
  # Jax double (64bit) precision is required to generate StableHLO mlir with
@@ -143,8 +144,39 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
143
144
  ir_inputs = []
144
145
 
145
146
  results = func.CallOp(cloned_func, ir_inputs).results
147
+
148
+ if lctx.node is None:
149
+ return results[0] if len(results) == 1 else results
150
+
151
+ out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
152
+
153
+ if out_avals is None:
154
+ return results[0] if len(results) == 1 else results
155
+
156
+ def sanitize_result_elty(result, aval):
157
+ # JAX implementation may not respect aten op's output dtype. For example,
158
+ # JAX may implement a slightly different dtype upcast rules, leads to
159
+ # different result's dtype from bridged lowering and torch op output.
160
+ # Here we add an additional `stablehlo.convert` op when dtype does not
161
+ # match, to ensure the lowering's result dtype will always be the same
162
+ # as torch op's output dtype.
163
+ if aval is None:
164
+ return result
165
+
166
+ target_elty = export_utils.torch_dtype_to_ir_element_type(
167
+ lctx.ir_context, aval.dtype
168
+ )
169
+ if result.type.element_type == target_elty:
170
+ return result
171
+ return stablehlo.convert(
172
+ ir.RankedTensorType.get(result.type.shape, target_elty), result
173
+ )
174
+
146
175
  if len(results) == 1:
147
- return results[0]
148
- return results
176
+ return sanitize_result_elty(results[0], out_avals)
177
+ return [
178
+ sanitize_result_elty(result, aval)
179
+ for result, aval in zip(results, out_avals)
180
+ ]
149
181
 
150
182
  return wrapped
@@ -15,13 +15,17 @@
15
15
  import math
16
16
  from typing import Optional, Union
17
17
 
18
+ from ai_edge_torch.odml_torch import export_utils
19
+ from ai_edge_torch.odml_torch.lowerings import context
20
+ from ai_edge_torch.odml_torch.lowerings import registry
18
21
  from ai_edge_torch.odml_torch.lowerings import utils
19
22
  from jax._src.lib.mlir import ir
20
23
  from jax._src.lib.mlir.dialects import hlo as stablehlo
21
24
  import numpy as np
22
25
  import torch
23
26
 
24
- from .registry import lower
27
+ LoweringContext = context.LoweringContext
28
+ lower = registry.lower
25
29
 
26
30
 
27
31
  # add(Tensor self, Tensor other) -> Tensor
@@ -211,6 +215,31 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
211
215
  return stablehlo.floor(x)
212
216
 
213
217
 
218
# Schema:
# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
# Torch Reference:
# - https://pytorch.org/docs/main/generated/torch.cat.html
@lower(torch.ops.aten.cat.default)
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  """Lowers aten.cat to stablehlo.concatenate.

  Operands with zero elements are dropped before concatenating: they add
  nothing to the result, and torch.cat only requires non-empty inputs to
  agree in shape, so empty ones may not be valid concatenate operands. If
  every operand is empty, a zero-filled constant with the output shape and
  dtype (read from the FX node's metadata) is returned instead.

  Args:
    lctx: the lowering context.
    tensors: the ir.Value operands to concatenate.
    dim: the concatenation dimension; negative values count from the end.

  Returns:
    The concatenated ir.Value.

  Raises:
    ValueError: if `tensors` is empty.
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if not tensors:
    raise ValueError("aten.cat requires at least one input tensor.")

  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
  if not non_empty_tensors:
    # The output shape cannot be derived from all-empty operands, so only
    # this path reads the node metadata.
    out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
    return utils.splat(
        0,
        export_utils.torch_dtype_to_ir_element_type(
            lctx.ir_context, out_aval.dtype
        ),
        out_aval.shape,
    )

  # Non-empty operands share the output's rank, so normalize `dim` with it.
  rank = len(non_empty_tensors[0].type.shape)
  if dim < 0:
    dim += rank
  dim_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)
  return stablehlo.concatenate(non_empty_tensors, dim_attr)
241
+
242
+
214
243
  # Schema:
215
244
  # - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
216
245
  # start=None, SymInt? end=None, SymInt step=1) -> Tensor
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.bitwise_not)
105
105
  lower_by_torch_xla2(torch.ops.aten.bitwise_or)
106
106
  lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
107
107
  lower_by_torch_xla2(torch.ops.aten.bmm)
108
- lower_by_torch_xla2(torch.ops.aten.cat)
109
108
  lower_by_torch_xla2(torch.ops.aten.ceil)
110
109
  lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
111
110
  lower_by_torch_xla2(torch.ops.aten.clamp.default)
@@ -172,7 +171,6 @@ lower_by_torch_xla2(torch.ops.aten.mm)
172
171
  lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
173
172
  lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
174
173
  lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
175
- lower_by_torch_xla2(torch.ops.aten.native_group_norm)
176
174
  lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
177
175
  lower_by_torch_xla2(torch.ops.aten.ne)
178
176
  lower_by_torch_xla2(torch.ops.aten.neg)
@@ -61,6 +61,7 @@ global_registry.decompositions.update(
61
61
  torch.ops.aten._adaptive_avg_pool2d,
62
62
  torch.ops.aten._adaptive_avg_pool3d,
63
63
  torch.ops.aten.grid_sampler_2d,
64
+ torch.ops.aten.native_group_norm,
64
65
  torch.ops.aten.native_dropout,
65
66
  torch.ops.aten.reflection_pad1d,
66
67
  torch.ops.aten.reflection_pad2d,
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241120"
16
+ __version__ = "0.3.0.dev20241122"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241120
3
+ Version: 0.3.0.dev20241122
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
29
29
  Requires-Dist: tabulate
30
30
  Requires-Dist: torch>=2.4.0
31
31
  Requires-Dist: torch-xla>=2.4.0
32
- Requires-Dist: tf-nightly>=2.19.0.dev20241001
32
+ Requires-Dist: tf-nightly>=2.19.0.dev20241121
33
33
  Requires-Dist: ai-edge-litert-nightly
34
34
  Requires-Dist: ai-edge-quantizer-nightly
35
35
 
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=52sF7t2CBQE8RcB2Hcmo-f6_BLyCW9NzWZ-wTKM9ho4,706
6
+ ai_edge_torch/version.py,sha256=B4r6opjqsPmDJdLbwvWto6dM-0KbsjszxSL6CXmi8K8,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
27
27
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
28
28
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
29
29
  ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
30
- ai_edge_torch/_convert/test/test_convert_composites.py,sha256=ELwHxTdTTCJm30aWg_PZXxg9HvDM4Hnf9lT0wwOWT6s,8060
30
+ ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
31
31
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
32
32
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
33
33
  ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
@@ -120,7 +120,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
120
120
  ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
121
121
  ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
122
122
  ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
123
- ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
123
+ ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
124
124
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
125
125
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
126
126
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -134,6 +134,7 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
134
134
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
135
135
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
136
136
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
137
+ ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
137
138
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
138
139
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
139
140
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
@@ -142,6 +143,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
142
143
  ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
143
144
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
144
145
  ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
146
+ ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
145
147
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
146
148
  ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
147
149
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -158,7 +160,7 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
158
160
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
159
161
  ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
160
162
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
161
- ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
163
+ ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
162
164
  ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
163
165
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
164
166
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
@@ -175,16 +177,16 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
175
177
  ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
176
178
  ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
177
179
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
178
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
180
+ ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=oQo9nxH08NnEDeZaGoCUk1kRtoEOM_f0DUOyd9nfxjg,6673
179
181
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
180
182
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
181
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=z_hPJX9n97d6obcsS9OHXpKqbmw6QqACXgnq5ML6Rhs,9014
183
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=eH9eJqFO-BI9l4WdXfjsItODPRa18SAR_qSvJ6-7gxc,9987
182
184
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
183
185
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
184
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=1JeX3j7Rt3KE7Z2eYRrhtcYgO3EKnRyZFKAUWXw-bsU,10812
186
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
185
187
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
186
188
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
187
- ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
189
+ ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
188
190
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
189
191
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
190
192
  ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -194,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
194
196
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
195
197
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
196
198
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
197
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
198
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/METADATA,sha256=1Nv_QeerPRw888sOTf4jHx5Ihu-PJD9rL8GOpRHSTa4,1897
199
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
200
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
201
- ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/RECORD,,
199
+ ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
200
+ ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/METADATA,sha256=-YpC-ksRKR8hJ8pZET4Q2F5KbUiRmGOXPhBoEQgIuOA,1897
201
+ ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
202
+ ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
203
+ ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/RECORD,,