ai-edge-torch-nightly 0.3.0.dev20241121__py3-none-any.whl → 0.3.0.dev20241122__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- ai_edge_torch/_convert/test/test_convert_composites.py +0 -1
- ai_edge_torch/generative/layers/rotary_position_embedding.py +50 -0
- ai_edge_torch/generative/test/test_custom_dus.py +107 -0
- ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +1 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +34 -2
- ai_edge_torch/odml_torch/lowerings/_basic.py +4 -4
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/RECORD +15 -13
- {ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert_composites.py
@@ -39,7 +39,6 @@ def _func_to_torch_module(func: Callable[..., torch.Tensor]):
   return TestModule(func).eval()
 
 
-@googletest.skip('Temporary outage due to changes for b/377531086')
 class TestConvertComposites(googletest.TestCase):
   """Tests conversion modules that are meant to be wrapped as composites."""
 
ai_edge_torch/generative/layers/rotary_position_embedding.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 # Implementation for Rotary Position embedding. https://arxiv.org/pdf/2104.09864.pdf
+from typing import Tuple
 import torch
 
 
@@ -36,3 +37,52 @@ def apply_rope(
   rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
   roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)
+
+
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    input_pos: torch.Tensor,
+    n_elem: int,
+    base: int = 10_000,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Computes rotary positional embedding inline for a query and key.
+
+  Args:
+    q: the query tensor.
+    k: the key tensor.
+    input_pos: the sequence indices for the query and key
+    n_elem: number of elements of the head dimension for RoPE computation
+
+  Returns:
+    output the RoPE'd query and key.
+  """
+
+  if n_elem <= 0:
+    return q, k
+
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
+  freq_exponents = (2.0 / n_elem) * torch.arange(
+      q.shape[-1] // 2, dtype=torch.float32
+  )
+  timescale = float(base) ** freq_exponents
+  radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
+      0
+  ).unsqueeze(0)
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)
+
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res
+
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
+  return q_roped, k_roped
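For orientation, a minimal usage sketch of the new helper. The (batch, seq_len, num_heads, head_dim) layout and the module alias are assumptions inferred from the code above, not part of the diff:

import torch
from ai_edge_torch.generative.layers import rotary_position_embedding as rope

# Query/key laid out as (batch, seq_len, num_heads, head_dim); input_pos holds
# the absolute positions of the tokens in the current slice of the sequence.
B, T, NH, HD = 1, 16, 4, 64
q = torch.randn(B, T, NH, HD)
k = torch.randn(B, T, NH, HD)
input_pos = torch.arange(T)

q_roped, k_roped = rope.apply_rope_inline(q, k, input_pos, n_elem=HD)
assert q_roped.shape == q.shape and k_roped.shape == k.shape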
ai_edge_torch/generative/test/test_custom_dus.py
@@ -0,0 +1,107 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A suite of tests to validate the Dynamic Update Slice Custom Op."""
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import torch
+from torch import nn
+
+from absl.testing import absltest as googletest, parameterized
+
+
+def updated_slice_matches(buffer, update, index):
+  indexer = [slice(i, i + d) for i, d in zip(index, update.shape)]
+  buf = buffer[indexer]
+  return torch.allclose(buf, update)
+
+
+def intT(x):
+  return torch.tensor(x).int()
+
+
+class DUSMod(nn.Module):
+
+  def forward(self, buffer, update, index):
+    out = dynamic_update_slice(buffer, update, index)
+    out = out * 2
+    return out
+
+
+@googletest.skip('Enable this when odml_torch is default b/373387583')
+class TestCustomDUS(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      (
+          'DUS_whole_buffer',
+          torch.randn(1, 1280, 4, 64),
+          torch.randn([1, 1024, 4, 64]),
+          [intT(0), intT(0), intT(0), intT(0)],
+      ),
+      (
+          'DUS_kv_example',
+          torch.randn(2, 1280, 4, 64),
+          torch.randn([2, 1024, 4, 64]),
+          [intT(0), intT(0), intT(0), intT(0)],
+      ),
+      (
+          'DUS_3d',
+          torch.randn(2, 256, 4, 64),
+          torch.randn([2, 256, 2, 64]),
+          [intT(0), intT(0), intT(2), intT(0)],
+      ),
+      (
+          'DUS_3d_v2',
+          torch.randn(2, 256, 4, 64),
+          torch.randn([2, 256, 3, 64]),
+          [intT(0), intT(0), intT(1), intT(0)],
+      ),
+      (
+          'DUS_3d_v3',
+          torch.randn(6, 8, 32),
+          torch.randn([6, 3, 32]),
+          [intT(0), intT(5), intT(0)],
+      ),
+      (
+          'DUS_2d',
+          torch.randn(8, 32),
+          torch.randn([8, 12]),
+          [intT(0), intT(20)],
+      ),
+  )
+  def test_opcheck_dynamic_update_slice(self, buffer, update, indices):
+    torch.library.opcheck(dynamic_update_slice, (buffer, update, indices))
+    out = dynamic_update_slice(buffer, update, indices)
+    self.assertTrue(updated_slice_matches(out, update, indices))
+
+  def test_exported_program(self):
+    buffer = torch.randn(1, 1280, 4, 64)
+    update = torch.randn([1, 1024, 4, 64])
+    index = [intT(0), intT(0), intT(0), intT(0)]
+    dm = DUSMod()
+    ep = torch.export.export(dm, (buffer, update, index))
+    dus_in_exported_program = False
+    for node in ep.graph.nodes:
+      if node.op == 'call_function':
+        if node.target.__name__.startswith('dynamic_update_slice'):
+          dus_in_exported_program = True
+          break
+
+    self.assertTrue(dus_in_exported_program)
+
+
+if __name__ == '__main__':
+  googletest.main()
ai_edge_torch/generative/utilities/dynamic_update_slice.py
@@ -0,0 +1,56 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+import glob
+import os
+from typing import Sequence
+from ai_edge_torch.odml_torch import lowerings
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+
+# Use torch.library.custom_op to define a new custom operator.
+# TODO: Update impl for multiple non-trivial start_indices
+@torch.library.custom_op("ai_edge_torch::dynamic_update_slice", mutates_args=())
+def dynamic_update_slice(
+    in_tensor: torch.Tensor,
+    update: torch.Tensor,
+    start_indices: Sequence[torch.Tensor],
+) -> torch.Tensor:
+  compare_size = torch.tensor(in_tensor.size()) == torch.tensor(update.size())
+  mismatch = torch.nonzero(~compare_size, as_tuple=False)
+  dim = mismatch[0].item() if len(mismatch) > 0 else 0
+  start = start_indices[dim].item()
+  end = start + update.shape[dim]
+  indices = torch.arange(start, end).to(torch.long)
+  return in_tensor.index_copy(dim, indices, update)
+
+
+# Use register_fake to add a ``FakeTensor`` kernel for the operator
+@dynamic_update_slice.register_fake
+def _(in_tensor, update, start_indices):
+  return in_tensor.clone().detach()
+
+
+@lowerings.lower(torch.ops.ai_edge_torch.dynamic_update_slice)
+def _dynamic_update_slice_lower(
+    lctx,
+    in_tensor: ir.Value,
+    update: ir.Value,
+    start_indices: Sequence[ir.Value],
+):
+  return stablehlo.dynamic_update_slice(in_tensor, update, start_indices)
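A minimal usage sketch of the new custom op in eager mode. The tensor shapes and start indices are illustrative assumptions; only the function itself comes from the file above:

import torch
from ai_edge_torch.generative.utilities.dynamic_update_slice import (
    dynamic_update_slice,
)

cache = torch.zeros(1, 8, 4, 16)    # e.g. a KV-cache buffer
update = torch.randn(1, 3, 4, 16)   # new entries for three positions
start = [torch.tensor(i).int() for i in (0, 5, 0, 0)]

out = dynamic_update_slice(cache, update, start)
# Per the TODO in the implementation, only the start index along the single
# dimension where buffer and update shapes differ (dim=1 here) is honoured;
# the remaining indices are expected to be zero.
assert torch.equal(out[:, 5:8], update)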
ai_edge_torch/lowertools/odml_torch_utils.py
@@ -185,6 +185,7 @@ def merged_bundle_to_tfl_model(
   converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
   converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
   converter._experimental_enable_composite_direct_lowering = True
+  converter._experimental_enable_dynamic_update_slice = True
   converter.model_origin_framework = "PYTORCH"
 
   conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
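For context, a rough sketch of where the new flag sits in a plain TFLite conversion flow; the saved-model path is a placeholder, and only the two experimental flags come from this diff:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/exported_saved_model")
converter._experimental_enable_composite_direct_lowering = True
# Added in this release: enable dynamic_update_slice support in the converter,
# matching the custom op introduced above.
converter._experimental_enable_dynamic_update_slice = True
tflite_model = converter.convert()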
ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -24,6 +24,7 @@ from ai_edge_torch.odml_torch.jax_bridge import utils
 import jax
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
 import torch.utils._pytree as pytree
 
 # Jax double (64bit) precision is required to generate StableHLO mlir with
@@ -143,8 +144,39 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
     ir_inputs = []
 
     results = func.CallOp(cloned_func, ir_inputs).results
+
+    if lctx.node is None:
+      return results[0] if len(results) == 1 else results
+
+    out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
+
+    if out_avals is None:
+      return results[0] if len(results) == 1 else results
+
+    def sanitize_result_elty(result, aval):
+      # JAX implementation may not respect aten op's output dtype. For example,
+      # JAX may implement a slightly different dtype upcast rules, leads to
+      # different result's dtype from bridged lowering and torch op output.
+      # Here we add an additional `stablehlo.convert` op when dtype does not
+      # match, to ensure the lowering's result dtype will always be the same
+      # as torch op's output dtype.
+      if aval is None:
+        return result
+
+      target_elty = export_utils.torch_dtype_to_ir_element_type(
+          lctx.ir_context, aval.dtype
+      )
+      if result.type.element_type == target_elty:
+        return result
+      return stablehlo.convert(
+          ir.RankedTensorType.get(result.type.shape, target_elty), result
+      )
+
     if len(results) == 1:
-      return results[0]
-    return results
+      return sanitize_result_elty(results[0], out_avals)
+    return [
+        sanitize_result_elty(result, aval)
+        for result, aval in zip(results, out_avals)
+    ]
 
   return wrapped
ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -223,18 +223,18 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
 def _aten_cat(lctx: LoweringContext, tensors, dim=0):
   assert tensors
   non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
-
+  out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
   if not non_empty_tensors:
     return utils.splat(
         0,
         export_utils.torch_dtype_to_ir_element_type(
-            lctx.ir_context,
+            lctx.ir_context, out_aval.dtype
         ),
-
+        out_aval.shape,
     )
 
   if dim < 0:
-    dim = dim + len(
+    dim = dim + len(out_aval.shape)
   dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)
 
   return stablehlo.concatenate(non_empty_tensors, dim)
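The change above makes the cat lowering read the output shape and dtype from the FX node's metadata rather than from the input values, which get dropped when they are empty. A small eager-mode illustration of the cases it has to cover, using plain torch (not part of the lowering itself):

import torch

# Empty inputs are filtered out before concatenation, so the op's recorded
# output metadata is the only reliable source of the result's shape and dtype.
a = torch.randn(0, 4)
b = torch.randn(2, 4)
print(torch.cat([a, b], dim=-2).shape)  # torch.Size([2, 4]); the negative dim
                                        # must be resolved against the output rank
print(torch.cat([a, a], dim=0).shape)   # torch.Size([0, 4]); all inputs empty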
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -171,7 +171,6 @@ lower_by_torch_xla2(torch.ops.aten.mm)
 lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
 lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
 lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
-lower_by_torch_xla2(torch.ops.aten.native_group_norm)
 lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
 lower_by_torch_xla2(torch.ops.aten.ne)
 lower_by_torch_xla2(torch.ops.aten.neg)
ai_edge_torch/odml_torch/lowerings/registry.py
@@ -61,6 +61,7 @@ global_registry.decompositions.update(
         torch.ops.aten._adaptive_avg_pool2d,
         torch.ops.aten._adaptive_avg_pool3d,
         torch.ops.aten.grid_sampler_2d,
+        torch.ops.aten.native_group_norm,
         torch.ops.aten.native_dropout,
         torch.ops.aten.reflection_pad1d,
         torch.ops.aten.reflection_pad2d,
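Taken together with the _jax_lowerings.py change above, aten.native_group_norm is no longer lowered through the JAX bridge but handled via decomposition. A hedged sketch of what that registration amounts to in stock torch terms; the exact table construction inside ai_edge_torch's registry may differ:

import torch
from torch._decomp import get_decompositions

# Conceptually, registering aten.native_group_norm as a decomposition means the
# conversion pipeline expands it into simpler aten ops (using torch's registered
# decomposition, when one exists) instead of needing a dedicated lowering.
table = get_decompositions([torch.ops.aten.native_group_norm])
print(table)  # non-empty iff torch ships a decomposition for this op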
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241121
+Version: 0.3.0.dev20241122
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
 Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
-Requires-Dist: tf-nightly>=2.19.0.
+Requires-Dist: tf-nightly>=2.19.0.dev20241121
 Requires-Dist: ai-edge-litert-nightly
 Requires-Dist: ai-edge-quantizer-nightly
 
{ai_edge_torch_nightly-0.3.0.dev20241121.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/RECORD
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=B4r6opjqsPmDJdLbwvWto6dM-0KbsjszxSL6CXmi8K8,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
-ai_edge_torch/_convert/test/test_convert_composites.py,sha256=
+ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
 ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
@@ -120,7 +120,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
 ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -134,6 +134,7 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
@@ -142,6 +143,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
 ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
+ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -158,7 +160,7 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
@@ -175,16 +177,16 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
-ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
+ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=oQo9nxH08NnEDeZaGoCUk1kRtoEOM_f0DUOyd9nfxjg,6673
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=eH9eJqFO-BI9l4WdXfjsItODPRa18SAR_qSvJ6-7gxc,9987
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
+ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -194,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/METADATA,sha256=-YpC-ksRKR8hJ8pZET4Q2F5KbUiRmGOXPhBoEQgIuOA,1897
+ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/RECORD,,