ai-edge-torch-nightly 0.3.0.dev20241120__py3-none-any.whl → 0.3.0.dev20241122__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/test/test_convert_composites.py +0 -1
- ai_edge_torch/generative/layers/rotary_position_embedding.py +50 -0
- ai_edge_torch/generative/test/test_custom_dus.py +107 -0
- ai_edge_torch/generative/utilities/dynamic_update_slice.py +56 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +1 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +34 -2
- ai_edge_torch/odml_torch/lowerings/_basic.py +30 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -2
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241120.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20241120.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/RECORD +15 -13
- {ai_edge_torch_nightly-0.3.0.dev20241120.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241120.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241120.dist-info → ai_edge_torch_nightly-0.3.0.dev20241122.dist-info}/top_level.txt +0 -0
@@ -39,7 +39,6 @@ def _func_to_torch_module(func: Callable[..., torch.Tensor]):
|
|
39
39
|
return TestModule(func).eval()
|
40
40
|
|
41
41
|
|
42
|
-
@googletest.skip('Temporary outage due to changes for b/377531086')
|
43
42
|
class TestConvertComposites(googletest.TestCase):
|
44
43
|
"""Tests conversion modules that are meant to be wrapped as composites."""
|
45
44
|
|
@@ -13,6 +13,7 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
15
|
# Implementation for Rotary Position embedding. https://arxiv.org/pdf/2104.09864.pdf
|
16
|
+
from typing import Tuple
|
16
17
|
import torch
|
17
18
|
|
18
19
|
|
@@ -36,3 +37,52 @@ def apply_rope(
|
|
36
37
|
rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
|
37
38
|
roped = (x * cos) + (rotated * sin)
|
38
39
|
return roped.transpose(1, 2).type_as(x)
|
40
|
+
|
41
|
+
|
42
|
+
def apply_rope_inline(
    q: torch.Tensor,
    k: torch.Tensor,
    input_pos: torch.Tensor,
    n_elem: int,
    base: int = 10_000,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Computes rotary positional embedding inline for a query and key.

  The rotation angles are derived on the fly from `input_pos`. The last
  dimension of each input is split into two halves `(x1, x2)`, which are
  rotated as `(x1 * cos - x2 * sin, x2 * cos + x1 * sin)`.

  Args:
    q: the query tensor; must be 4-D. Its axis 1 is broadcast against
      `input_pos`, so it must have `len(input_pos)` entries. Presumably the
      layout is (batch, seq_len, num_heads, head_dim) -- TODO confirm.
    k: the key tensor, with the same layout as `q` (axis 2 may differ).
    input_pos: the sequence indices for the query and key.
    n_elem: number of elements of the head dimension for RoPE computation.
      It scales the frequency exponents (`2 * i / n_elem`). If `n_elem <= 0`,
      `q` and `k` are returned unchanged.
    base: base of the geometric progression of rotation frequencies.

  Returns:
    The RoPE'd query and key, in the same layout as the inputs. The cos/sin
    factors are cast to `q`'s dtype and applied to both tensors.
  """
  if n_elem <= 0:
    return q, k

  # One frequency per (first-half, second-half) element pair of the last dim:
  # timescale_i = base ** (2 * i / n_elem).
  freq_exponents = (2.0 / n_elem) * torch.arange(
      q.shape[-1] // 2, dtype=torch.float32
  )
  timescale = float(base) ** freq_exponents
  # input_pos -> (1, T, 1); timescale -> (1, 1, D/2); radians -> (1, T, D/2).
  radians = input_pos.unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
      0
  ).unsqueeze(0)
  cos = torch.cos(radians).type_as(q)
  sin = torch.sin(radians).type_as(q)

  def _rotate(x: torch.Tensor) -> torch.Tensor:
    # Swap axes 1 and 2 so the position axis lines up with radians' axis 1.
    x = x.transpose(1, 2)
    # Unpacking also enforces that the input is exactly 4-D.
    _, _, _, head_dim = x.shape
    x1, x2 = torch.split(x, head_dim // 2, dim=-1)
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    return rotated.transpose(1, 2)

  return _rotate(q), _rotate(k)
|
@@ -0,0 +1,107 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""A suite of tests to validate the Dynamic Update Slice Custom Op."""
|
17
|
+
|
18
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg
from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
import torch
from torch import nn

from absl.testing import absltest as googletest, parameterized
|
24
|
+
|
25
|
+
|
26
|
+
def updated_slice_matches(buffer, update, index):
  """Returns True if `buffer` holds `update` at the offsets given by `index`.

  Args:
    buffer: the tensor that is expected to contain `update`.
    update: the expected contents of the region.
    index: per-dimension start offsets (python ints or 0-d int tensors).
  """
  # Build an explicit tuple of slices with plain-int bounds rather than
  # indexing with a list, which relies on the legacy "sequence treated as a
  # tuple" indexing heuristic.
  region = tuple(
      slice(int(start), int(start) + size)
      for start, size in zip(index, update.shape)
  )
  return torch.allclose(buffer[region], update)
|
30
|
+
|
31
|
+
|
32
|
+
def intT(x):
  """Wraps `x` in a tensor and casts it to int32."""
  as_tensor = torch.tensor(x)
  return as_tensor.to(torch.int32)
|
34
|
+
|
35
|
+
|
36
|
+
class DUSMod(nn.Module):
  """Module that applies `dynamic_update_slice` and doubles the result."""

  def forward(self, buffer, update, index):
    return dynamic_update_slice(buffer, update, index) * 2
|
42
|
+
|
43
|
+
|
44
|
+
@googletest.skip('Enable this when odml_torch is default b/373387583')
class TestCustomDUS(parameterized.TestCase):
  """Tests for the custom dynamic_update_slice op."""

  @parameterized.named_parameters(
      (
          'DUS_whole_buffer',
          torch.randn(1, 1280, 4, 64),
          torch.randn([1, 1024, 4, 64]),
          [intT(0), intT(0), intT(0), intT(0)],
      ),
      (
          'DUS_kv_example',
          torch.randn(2, 1280, 4, 64),
          torch.randn([2, 1024, 4, 64]),
          [intT(0), intT(0), intT(0), intT(0)],
      ),
      (
          'DUS_3d',
          torch.randn(2, 256, 4, 64),
          torch.randn([2, 256, 2, 64]),
          [intT(0), intT(0), intT(2), intT(0)],
      ),
      (
          'DUS_3d_v2',
          torch.randn(2, 256, 4, 64),
          torch.randn([2, 256, 3, 64]),
          [intT(0), intT(0), intT(1), intT(0)],
      ),
      (
          'DUS_3d_v3',
          torch.randn(6, 8, 32),
          torch.randn([6, 3, 32]),
          [intT(0), intT(5), intT(0)],
      ),
      (
          'DUS_2d',
          torch.randn(8, 32),
          torch.randn([8, 12]),
          [intT(0), intT(20)],
      ),
  )
  def test_opcheck_dynamic_update_slice(self, buffer, update, indices):
    """Runs torch.library.opcheck, then checks the update landed in place."""
    torch.library.opcheck(dynamic_update_slice, (buffer, update, indices))
    result = dynamic_update_slice(buffer, update, indices)
    self.assertTrue(updated_slice_matches(result, update, indices))

  def test_exported_program(self):
    """Checks the op appears as a call_function node after torch.export."""
    buffer = torch.randn(1, 1280, 4, 64)
    update = torch.randn([1, 1024, 4, 64])
    index = [intT(0), intT(0), intT(0), intT(0)]
    exported = torch.export.export(DUSMod(), (buffer, update, index))
    found_dus = any(
        node.op == 'call_function'
        and node.target.__name__.startswith('dynamic_update_slice')
        for node in exported.graph.nodes
    )
    self.assertTrue(found_dus)
|
104
|
+
|
105
|
+
|
106
|
+
if __name__ == '__main__':
  # Run this module's test cases via absltest when executed as a script.
  googletest.main()
|
@@ -0,0 +1,56 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
# Custom `dynamic_update_slice` op and its StableHLO lowering for odml_torch.
|
16
|
+
from dataclasses import dataclass
|
17
|
+
import glob
|
18
|
+
import os
|
19
|
+
from typing import Sequence
|
20
|
+
from ai_edge_torch.odml_torch import lowerings
|
21
|
+
from jax._src.lib.mlir import ir
|
22
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import torch
|
24
|
+
|
25
|
+
|
26
|
+
# Use torch.library.custom_op to define a new custom operator.
|
27
|
+
# TODO: Update impl for multiple non-trivial start_indices
|
28
|
+
@torch.library.custom_op("ai_edge_torch::dynamic_update_slice", mutates_args=())
|
29
|
+
def dynamic_update_slice(
|
30
|
+
in_tensor: torch.Tensor,
|
31
|
+
update: torch.Tensor,
|
32
|
+
start_indices: Sequence[torch.Tensor],
|
33
|
+
) -> torch.Tensor:
|
34
|
+
compare_size = torch.tensor(in_tensor.size()) == torch.tensor(update.size())
|
35
|
+
mismatch = torch.nonzero(~compare_size, as_tuple=False)
|
36
|
+
dim = mismatch[0].item() if len(mismatch) > 0 else 0
|
37
|
+
start = start_indices[dim].item()
|
38
|
+
end = start + update.shape[dim]
|
39
|
+
indices = torch.arange(start, end).to(torch.long)
|
40
|
+
return in_tensor.index_copy(dim, indices, update)
|
41
|
+
|
42
|
+
|
43
|
+
# Register a ``FakeTensor`` kernel for the operator, used when the op is traced
# with fake tensors (e.g. under torch.export).
@dynamic_update_slice.register_fake
def _(in_tensor, update, start_indices):
  # The result has the same shape and dtype as `in_tensor`.
  return in_tensor.detach().clone()
|
47
|
+
|
48
|
+
|
49
|
+
@lowerings.lower(torch.ops.ai_edge_torch.dynamic_update_slice)
def _dynamic_update_slice_lower(
    lctx,
    in_tensor: ir.Value,
    update: ir.Value,
    start_indices: Sequence[ir.Value],
):
  """Lowers the custom op one-to-one to `stablehlo.dynamic_update_slice`.

  `lctx` is unused; `start_indices` are forwarded unchanged as the StableHLO
  op's start-index operands.
  """
  lowered = stablehlo.dynamic_update_slice(in_tensor, update, start_indices)
  return lowered
|
@@ -185,6 +185,7 @@ def merged_bundle_to_tfl_model(
|
|
185
185
|
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
|
186
186
|
converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
|
187
187
|
converter._experimental_enable_composite_direct_lowering = True
|
188
|
+
converter._experimental_enable_dynamic_update_slice = True
|
188
189
|
converter.model_origin_framework = "PYTORCH"
|
189
190
|
|
190
191
|
conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
|
@@ -24,6 +24,7 @@ from ai_edge_torch.odml_torch.jax_bridge import utils
|
|
24
24
|
import jax
|
25
25
|
from jax._src.lib.mlir import ir
|
26
26
|
from jax._src.lib.mlir.dialects import func
|
27
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
27
28
|
import torch.utils._pytree as pytree
|
28
29
|
|
29
30
|
# Jax double (64bit) precision is required to generate StableHLO mlir with
|
@@ -143,8 +144,39 @@ def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
|
|
143
144
|
ir_inputs = []
|
144
145
|
|
145
146
|
results = func.CallOp(cloned_func, ir_inputs).results
|
147
|
+
|
148
|
+
if lctx.node is None:
|
149
|
+
return results[0] if len(results) == 1 else results
|
150
|
+
|
151
|
+
out_avals = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")
|
152
|
+
|
153
|
+
if out_avals is None:
|
154
|
+
return results[0] if len(results) == 1 else results
|
155
|
+
|
156
|
+
def sanitize_result_elty(result, aval):
|
157
|
+
# JAX implementation may not respect aten op's output dtype. For example,
|
158
|
+
# JAX may implement a slightly different dtype upcast rules, leads to
|
159
|
+
# different result's dtype from bridged lowering and torch op output.
|
160
|
+
# Here we add an additional `stablehlo.convert` op when dtype does not
|
161
|
+
# match, to ensure the lowering's result dtype will always be the same
|
162
|
+
# as torch op's output dtype.
|
163
|
+
if aval is None:
|
164
|
+
return result
|
165
|
+
|
166
|
+
target_elty = export_utils.torch_dtype_to_ir_element_type(
|
167
|
+
lctx.ir_context, aval.dtype
|
168
|
+
)
|
169
|
+
if result.type.element_type == target_elty:
|
170
|
+
return result
|
171
|
+
return stablehlo.convert(
|
172
|
+
ir.RankedTensorType.get(result.type.shape, target_elty), result
|
173
|
+
)
|
174
|
+
|
146
175
|
if len(results) == 1:
|
147
|
-
return results[0]
|
148
|
-
return
|
176
|
+
return sanitize_result_elty(results[0], out_avals)
|
177
|
+
return [
|
178
|
+
sanitize_result_elty(result, aval)
|
179
|
+
for result, aval in zip(results, out_avals)
|
180
|
+
]
|
149
181
|
|
150
182
|
return wrapped
|
@@ -15,13 +15,17 @@
|
|
15
15
|
import math
|
16
16
|
from typing import Optional, Union
|
17
17
|
|
18
|
+
from ai_edge_torch.odml_torch import export_utils
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
18
21
|
from ai_edge_torch.odml_torch.lowerings import utils
|
19
22
|
from jax._src.lib.mlir import ir
|
20
23
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
21
24
|
import numpy as np
|
22
25
|
import torch
|
23
26
|
|
24
|
-
|
27
|
+
LoweringContext = context.LoweringContext
|
28
|
+
lower = registry.lower
|
25
29
|
|
26
30
|
|
27
31
|
# add(Tensor self, Tensor other) -> Tensor
|
@@ -211,6 +215,31 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
|
|
211
215
|
return stablehlo.floor(x)
|
212
216
|
|
213
217
|
|
218
|
+
# Schema:
#   - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
# Torch Reference:
#   - https://pytorch.org/docs/main/generated/torch.cat.html
@lower(torch.ops.aten.cat.default)
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
  """Lowers `aten.cat` to `stablehlo.concatenate`.

  Args:
    lctx: the lowering context; node metadata (if present) provides the output
      dtype and shape.
    tensors: the IR values to concatenate.
    dim: the concatenation dimension; negative values count from the end.

  Returns:
    The concatenated IR value.

  Raises:
    ValueError: if `tensors` is empty.
  """
  if not tensors:
    raise ValueError("aten.cat requires at least one input tensor.")

  # Empty inputs contribute no elements (torch also accepts legacy 1-D empty
  # tensors whose rank differs from the others), so drop them.
  non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]

  out_aval = None
  if lctx.node is not None:
    out_aval = lctx.node.meta.get("tensor_meta") or lctx.node.meta.get("val")

  if not non_empty_tensors:
    # Every input is empty: only the node metadata knows the output type.
    return utils.splat(
        0,
        export_utils.torch_dtype_to_ir_element_type(
            lctx.ir_context, out_aval.dtype
        ),
        out_aval.shape,
    )

  if out_aval is not None:
    # torch.cat type-promotes mixed-dtype inputs, whereas
    # stablehlo.concatenate requires identical element types: convert every
    # input to the output dtype recorded on the node.
    target_elty = export_utils.torch_dtype_to_ir_element_type(
        lctx.ir_context, out_aval.dtype
    )
    non_empty_tensors = [
        t
        if t.type.element_type == target_elty
        else stablehlo.convert(
            ir.RankedTensorType.get(t.type.shape, target_elty), t
        )
        for t in non_empty_tensors
    ]

  # All remaining inputs share the output rank, so normalize against them
  # (this also works when node metadata is unavailable).
  rank = len(non_empty_tensors[0].type.shape)
  if dim < 0:
    dim += rank
  dim_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)

  return stablehlo.concatenate(non_empty_tensors, dim_attr)
|
241
|
+
|
242
|
+
|
214
243
|
# Schema:
|
215
244
|
# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
|
216
245
|
# start=None, SymInt? end=None, SymInt step=1) -> Tensor
|
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.bitwise_not)
|
|
105
105
|
lower_by_torch_xla2(torch.ops.aten.bitwise_or)
|
106
106
|
lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
|
107
107
|
lower_by_torch_xla2(torch.ops.aten.bmm)
|
108
|
-
lower_by_torch_xla2(torch.ops.aten.cat)
|
109
108
|
lower_by_torch_xla2(torch.ops.aten.ceil)
|
110
109
|
lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
|
111
110
|
lower_by_torch_xla2(torch.ops.aten.clamp.default)
|
@@ -172,7 +171,6 @@ lower_by_torch_xla2(torch.ops.aten.mm)
|
|
172
171
|
lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
|
173
172
|
lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
|
174
173
|
lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
|
175
|
-
lower_by_torch_xla2(torch.ops.aten.native_group_norm)
|
176
174
|
lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
|
177
175
|
lower_by_torch_xla2(torch.ops.aten.ne)
|
178
176
|
lower_by_torch_xla2(torch.ops.aten.neg)
|
@@ -61,6 +61,7 @@ global_registry.decompositions.update(
|
|
61
61
|
torch.ops.aten._adaptive_avg_pool2d,
|
62
62
|
torch.ops.aten._adaptive_avg_pool3d,
|
63
63
|
torch.ops.aten.grid_sampler_2d,
|
64
|
+
torch.ops.aten.native_group_norm,
|
64
65
|
torch.ops.aten.native_dropout,
|
65
66
|
torch.ops.aten.reflection_pad1d,
|
66
67
|
torch.ops.aten.reflection_pad2d,
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241122
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -29,7 +29,7 @@ Requires-Dist: safetensors
|
|
29
29
|
Requires-Dist: tabulate
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
|
-
Requires-Dist: tf-nightly>=2.19.0.
|
32
|
+
Requires-Dist: tf-nightly>=2.19.0.dev20241121
|
33
33
|
Requires-Dist: ai-edge-litert-nightly
|
34
34
|
Requires-Dist: ai-edge-quantizer-nightly
|
35
35
|
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=B4r6opjqsPmDJdLbwvWto6dM-0KbsjszxSL6CXmi8K8,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
29
|
ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
|
30
|
-
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=
|
30
|
+
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
33
33
|
ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
|
@@ -120,7 +120,7 @@ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-m
|
|
120
120
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
121
121
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
122
122
|
ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
|
123
|
-
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
|
123
|
+
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
124
124
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
125
125
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
126
126
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
|
@@ -134,6 +134,7 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
|
|
134
134
|
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
|
135
135
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
136
136
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
137
|
+
ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
|
137
138
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
138
139
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
139
140
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
|
@@ -142,6 +143,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
|
|
142
143
|
ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
|
143
144
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
144
145
|
ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
|
146
|
+
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
|
145
147
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
146
148
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
|
147
149
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
@@ -158,7 +160,7 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
|
|
158
160
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
159
161
|
ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
|
160
162
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
161
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
163
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
|
162
164
|
ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
|
163
165
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
|
164
166
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
@@ -175,16 +177,16 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
|
|
175
177
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
|
176
178
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
177
179
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
178
|
-
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
|
180
|
+
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=oQo9nxH08NnEDeZaGoCUk1kRtoEOM_f0DUOyd9nfxjg,6673
|
179
181
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
180
182
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
181
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
183
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=eH9eJqFO-BI9l4WdXfjsItODPRa18SAR_qSvJ6-7gxc,9987
|
182
184
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
183
185
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
184
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
186
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=4UyNyaR2W-vCOvj-P5lywQ1_RfLIxVE7J_GONI6CQvI,10718
|
185
187
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
186
188
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
187
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
189
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=itTt8MLbq2LoHTzRidCF2TTbh0TP7L836u99qCjP3FA,2953
|
188
190
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
189
191
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
190
192
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -194,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
194
196
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
195
197
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
196
198
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
197
|
-
ai_edge_torch_nightly-0.3.0.
|
198
|
-
ai_edge_torch_nightly-0.3.0.
|
199
|
-
ai_edge_torch_nightly-0.3.0.
|
200
|
-
ai_edge_torch_nightly-0.3.0.
|
201
|
-
ai_edge_torch_nightly-0.3.0.
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/METADATA,sha256=-YpC-ksRKR8hJ8pZET4Q2F5KbUiRmGOXPhBoEQgIuOA,1897
|
201
|
+
ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
202
|
+
ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/RECORD,,
|
File without changes
|
File without changes
|