ai-edge-torch-nightly 0.5.0.dev20250408__py3-none-any.whl → 0.5.0.dev20250410__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py +50 -0
- ai_edge_torch/_convert/test/test_convert.py +21 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +2 -2
- ai_edge_torch/generative/examples/gemma3/decoder.py +8 -9
- ai_edge_torch/generative/examples/gemma3/verify_util.py +4 -2
- ai_edge_torch/generative/layers/experimental/attention.py +10 -40
- ai_edge_torch/generative/layers/experimental/kv_cache.py +13 -283
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +6 -10
- ai_edge_torch/generative/layers/experimental/types.py +3 -0
- ai_edge_torch/generative/layers/kv_cache.py +81 -14
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +124 -0
- ai_edge_torch/generative/test/test_kv_cache.py +12 -19
- ai_edge_torch/generative/utilities/converter.py +8 -3
- ai_edge_torch/generative/utilities/export_config.py +3 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +19 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/METADATA +4 -2
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/RECORD +26 -24
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/kv_cache.py CHANGED
```diff
@@ -16,24 +16,58 @@
 """Utility functions for externalized KV Cache."""
 
 import dataclasses
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
 from ai_edge_torch.generative.layers import model_config
+from ai_edge_torch.generative.layers.experimental import types
 import torch
 import torch.utils._pytree as pytree
 
 
+KVLayout = Tuple[types.TensorDimensionMeta, types.TensorDimensionMeta]
+
+# Define common layouts for KV Cache.
+KV_LAYOUT_DEFAULT = (types.BTNH, types.BTNH)
+KV_LAYOUT_TRANSPOSED = (types.BNTH, types.BNHT)
+
+
 @dataclasses.dataclass
 class KVCacheEntry:
   """A single cache entry that includes K and V caches.
 
-  The
-  (batch_size=1, kv_cache_max, num_query_groups, head_dim).
+  The cache layout can be customized based on different use cases.
   """
 
   k_cache: torch.Tensor
   v_cache: torch.Tensor
+  kv_layout: KVLayout = KV_LAYOUT_DEFAULT
+
+  @classmethod
+  def construct_kv_shape_from_layout(
+      cls,
+      shape_spec: types.TensorDimensionMeta,
+      kv_cache_max: int,
+      config: model_config.AttentionConfig,
+      batch_size: int,
+  ) -> List[int]:
+    """Constructs the shape of the key or value cache entry based on
+
+    the specified layout.
+    """
+    output_shape = []
+    for dim_spec in shape_spec:
+      if dim_spec is types.TensorDims.BATCH:
+        output_shape.append(batch_size)
+      elif dim_spec is types.TensorDims.SEQUENCE:
+        output_shape.append(kv_cache_max)
+      elif dim_spec is types.TensorDims.NUM_HEADS:
+        output_shape.append(config.num_query_groups)
+      elif dim_spec is types.TensorDims.HEAD_DIM:
+        output_shape.append(config.head_dim)
+      else:
+        raise ValueError(f"Unsupported dimension spec: {dim_spec}")
+    return output_shape
 
   @classmethod
   def from_model_config(
@@ -41,14 +75,20 @@ class KVCacheEntry:
       kv_cache_max: int,
       config: model_config.AttentionConfig,
       dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
+      device: torch.device | None = None,
       batch_size: int = 1,
+      kv_layout: KVLayout = KV_LAYOUT_DEFAULT,
   ) -> "KVCacheEntry":
     """Build an instance of the class based on model config."""
-
-
-
-
+    k_shape = cls.construct_kv_shape_from_layout(
+        kv_layout[0], kv_cache_max, config, batch_size
+    )
+    v_shape = cls.construct_kv_shape_from_layout(
+        kv_layout[1], kv_cache_max, config, batch_size
+    )
+    k = torch.zeros(k_shape, dtype=dtype, device=device)
+    v = torch.zeros(v_shape, dtype=dtype, device=device)
+    obj = cls(k_cache=k, v_cache=v, kv_layout=kv_layout)
     return obj
 
 
@@ -63,8 +103,9 @@ class KVCache:
       cls,
       config: model_config.ModelConfig,
       dtype: torch.dtype = torch.float32,
-      device: torch.device = None,
+      device: torch.device | None = None,
       batch_size: int = 1,
+      kv_layout: KVLayout = KV_LAYOUT_DEFAULT,
   ) -> "KVCache":
     """Build an instance of the class based on model config.
 
@@ -89,6 +130,7 @@ class KVCache:
             dtype,
             device,
             batch_size,
+            kv_layout,
         )
         for idx in range(config.num_layers)
     ]
@@ -104,7 +146,7 @@ class KVCache:
 def _flatten_kvc(kvc: KVCache) -> Tuple[List[str], List[str]]:
   flattened = []
   flat_names = []
-  none_names = []
+  none_names = [kvc.caches[0].kv_layout]
   for i, kv_entry in enumerate(kvc.caches):
     flattened.append(kv_entry.k_cache)
     flat_names.append(f"k_{i}")
@@ -121,22 +163,48 @@ def _flatten_kvc_with_keys(kvc: KVCache) -> Tuple[List, List]:
 
 
 def _unflatten_kvc(
-    values: List[torch.Tensor],
+    values: List[torch.Tensor],
+    context: Tuple[List, List],
 ) -> KVCache:
   assert len(values) % 2 == 0, "Found odd number of K and V entries."
   num_layers = len(values) // 2
   flat_names = context[0]
+  kv_layout = context[1][0]
   kv_entries = []
   for i in range(num_layers):
     k_cache_idx = flat_names.index(f"k_{i}")
    v_cache_idx = flat_names.index(f"v_{i}")
    kv_entries.append(
-        KVCacheEntry(
+        KVCacheEntry(
+            k_cache=values[k_cache_idx],
+            v_cache=values[v_cache_idx],
+            kv_layout=kv_layout,
+        )
     )
   obj = KVCache(tuple(kv_entries))
   return obj
 
 
+def _flatten_kv_entry(
+    kv_e: KVCacheEntry,
+) -> Tuple[List[torch.Tensor], Any]:
+  return ([kv_e.k_cache, kv_e.v_cache], kv_e.kv_layout)
+
+
+def _unflatten_kv_entry(
+    values: List[torch.Tensor],
+    context: Any,
+) -> KVCacheEntry:
+  return KVCacheEntry(*values, kv_layout=context)
+
+
+pytree.register_pytree_node(
+    KVCacheEntry,
+    _flatten_kv_entry,
+    _unflatten_kv_entry,
+    serialized_type_name="",
+)
+
 pytree.register_pytree_node(
     KVCache,
     _flatten_kvc,
@@ -145,7 +213,6 @@ pytree.register_pytree_node(
     serialized_type_name="",
 )
 
-
 def update(
     cache: KVCacheEntry,
     input_pos: torch.Tensor,
@@ -204,5 +271,5 @@ def _update_kv_impl(
   k = dynamic_update_slice(cache.k_cache, k_slice, k_slice_indices)
   v = dynamic_update_slice(cache.v_cache, v_slice, v_slice_indices)
 
-  updated_cache = KVCacheEntry(k, v)
+  updated_cache = KVCacheEntry(k, v, cache.kv_layout)
   return updated_cache
```
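Illustrative sketch (not part of the diff) of the layout machinery added above: `construct_kv_shape_from_layout` walks the dimension tags in a `KVLayout` tuple, so `KV_LAYOUT_DEFAULT` keeps K and V in `(batch, kv_cache_max, num_query_groups, head_dim)` while `KV_LAYOUT_TRANSPOSED` stores K as `(batch, groups, kv_cache_max, head_dim)` and V as `(batch, groups, head_dim, kv_cache_max)`. The config values below are invented for the example.

```python
from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.model_config as cfg

# Invented example config: 4 query groups, head dimension 64.
attn_config = cfg.AttentionConfig(num_heads=8, head_dim=64, num_query_groups=4)

# Default layout: K and V are both (batch, kv_cache_max, num_query_groups, head_dim).
default_entry = kv_utils.KVCacheEntry.from_model_config(1024, attn_config)
assert tuple(default_entry.k_cache.shape) == (1, 1024, 4, 64)

# Transposed layout: K is (batch, groups, kv_cache_max, head_dim),
# V is (batch, groups, head_dim, kv_cache_max).
transposed_entry = kv_utils.KVCacheEntry.from_model_config(
    1024, attn_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
)
assert tuple(transposed_entry.k_cache.shape) == (1, 4, 1024, 64)
assert tuple(transposed_entry.v_cache.shape) == (1, 4, 64, 1024)
```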
ai_edge_torch/generative/layers/sdpa_with_kv_update.py ADDED
```diff
@@ -0,0 +1,124 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+from typing import Tuple
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
+from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
+from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers.experimental import types
+import ai_edge_torch.generative.layers.model_config as cfg
+from multipledispatch import dispatch
+import torch
+
+
+def sdpa_with_kv_update(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  return sdpa_with_kv_update_impl(
+      kv.kv_layout[0](),  # key layout
+      kv.kv_layout[1](),  # value layout
+      query=query,
+      key=key,
+      value=value,
+      kv=kv,
+      input_pos=input_pos,
+      mask=mask,
+      config=config,
+  )
+
+
+@dispatch(types.BNTH, types.BNHT)
+def sdpa_with_kv_update_impl(
+    k_type, v_type, *args, **kwargs
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  query = kwargs["query"]
+  key = kwargs["key"]
+  value = kwargs["value"]
+  kv = kwargs["kv"]
+  input_pos = kwargs["input_pos"]
+  mask = kwargs["mask"]
+  config = kwargs["config"]
+
+  # Transpose k/v to specific layout for GPU implementation.
+  b, seq_len, n, h = query.shape
+  g = n // config.num_query_groups
+  # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
+  query = query.permute(0, 2, 1, 3).reshape(
+      1, b * config.num_query_groups, g * seq_len, h
+  )
+
+  key = key.permute(0, 2, 1, 3).reshape(
+      1, -1, seq_len, config.head_dim
+  )  # 1, bk, s, h
+  value = value.permute(0, 2, 3, 1).reshape(
+      1, -1, config.head_dim, seq_len
+  )  # 1, bk, h, s
+
+  if kv is not None:
+    kv = kv_utils_experimental.update(kv, input_pos, key, value)
+    key, value = kv.k_cache, kv.v_cache
+
+  sdpa_out = sdpa.scaled_dot_product_attention(
+      kv,
+      query,
+      key,
+      value,
+      config.head_dim,
+      mask=mask,
+      softcap=config.logit_softcap,
+  )  # 1, bk, gt, h
+  sdpa_out = (
+      sdpa_out.reshape(b, -1, seq_len, h)
+      .permute(0, 2, 1, 3)
+      .reshape(b, seq_len, -1)
+  )
+  return sdpa_out, kv
+
+
+@dispatch(object, object)
+def sdpa_with_kv_update_impl(
+    k_type, v_type, *args, **kwargs
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  query = kwargs["query"]
+  key = kwargs["key"]
+  value = kwargs["value"]
+  kv = kwargs["kv"]
+  input_pos = kwargs["input_pos"]
+  mask = kwargs["mask"]
+  config = kwargs["config"]
+
+  b, seq_len, _, _ = query.shape
+  if kv is not None:
+    kv = kv_utils.update(kv, input_pos, key, value)
+    key, value = kv.k_cache, kv.v_cache
+
+  sdpa_out = sdpa_default.scaled_dot_product_attention(
+      query,
+      key,
+      value,
+      config.head_dim,
+      mask=mask,
+      softcap=config.logit_softcap,
+  )
+  sdpa_out = sdpa_out.reshape(b, seq_len, -1)
+  return sdpa_out, kv
```
ai_edge_torch/generative/test/test_kv_cache.py CHANGED
```diff
@@ -16,7 +16,6 @@
 """A suite of tests to validate KV Cache layer."""
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
 import torch.utils._pytree as pytree
@@ -117,7 +116,7 @@ class TestKVLayers(googletest.TestCase):
     self.assertEqual(input_specs[0].arg.name, "kv_k_0")
     self.assertEqual(input_specs[1].arg.name, "kv_v_0")
 
-  def
+  def test_pytree_roundtrip_kv_cache(self):
     NUM_LAYERS = 4
     config = self._get_test_config(
         num_layers=NUM_LAYERS,
@@ -125,15 +124,13 @@ class TestKVLayers(googletest.TestCase):
         num_query_groups=1,
         kv_cache_max_len=4,
     )
-    kv =
-        config, batch_size=1
-    )
+    kv = kv_utils.KVCache.from_model_config(config, batch_size=1)
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
     self.assertEqual(kv, kv_unflat)
 
-  def
+  def test_pytree_roundtrip_kv_cache_derived(self):
     NUM_LAYERS = 4
     config = self._get_test_config(
         num_layers=NUM_LAYERS,
@@ -141,41 +138,37 @@ class TestKVLayers(googletest.TestCase):
         num_query_groups=1,
         kv_cache_max_len=4,
     )
-    kv =
-        config, batch_size=1
+    kv = kv_utils.KVCache.from_model_config(
+        config, batch_size=1, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
     )
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, NUM_LAYERS * 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
     self.assertEqual(kv, kv_unflat)
 
-  def
+  def test_pytree_roundtrip_kv_entry(self):
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=1, num_query_groups=1
     )
-    kv =
-        32, attn_config
-    )
+    kv = kv_utils.KVCacheEntry.from_model_config(32, attn_config)
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
     self.assertEqual(kv, kv_unflat)
-    self.assertIsInstance(kv_unflat,
+    self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
 
-  def
+  def test_pytree_roundtrip_kv_entry_derived(self):
     attn_config = cfg.AttentionConfig(
         num_heads=1, head_dim=1, num_query_groups=1
     )
-    kv =
-        32, attn_config
+    kv = kv_utils.KVCacheEntry.from_model_config(
+        32, attn_config, kv_layout=kv_utils.KV_LAYOUT_TRANSPOSED
     )
     flat, treespec = pytree.tree_flatten(kv)
     self.assertLen(flat, 2)
     kv_unflat = pytree.tree_unflatten(flat, treespec)
     self.assertEqual(kv, kv_unflat)
-    self.assertIsInstance(
-        kv_unflat, kv_utils_experimental.KVCacheEntryTransposed
-    )
+    self.assertIsInstance(kv_unflat, kv_utils.KVCacheEntry)
 
 
 if __name__ == "__main__":
```
ai_edge_torch/generative/utilities/converter.py CHANGED
```diff
@@ -20,6 +20,7 @@ import pathlib
 from typing import Optional, Union
 from absl import flags
 from ai_edge_torch._convert import converter as converter_utils
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
@@ -218,9 +219,13 @@ def _export_helper(
       [[0] for _ in range(export_config.decode_batch_size)], dtype=torch.int
   )
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  prefill_kv =
-
-
+  prefill_kv = kv_utils.KVCache.from_model_config(
+      config, kv_layout=export_config.kvcache_layout
+  )
+  decode_kv = kv_utils.KVCache.from_model_config(
+      config,
+      batch_size=export_config.decode_batch_size,
+      kv_layout=export_config.kvcache_layout,
   )
 
   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
```
ai_edge_torch/generative/utilities/export_config.py CHANGED
```diff
@@ -32,7 +32,9 @@ class ExportConfig:
   # Attention masks given as inputs to the model.
   prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
   decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
-  # The KV Cache
+  # The KV Cache layout for K and V buffers in attention.
+  kvcache_layout: kv_utils.KVLayout = kv_utils.KV_LAYOUT_DEFAULT
+  # TODO(b/409373223): The KV Cache class for K and V buffers in attention.
   kvcache_cls: type = kv_utils.KVCache
   # The batch size of the decode signature.
   decode_batch_size: int = 1
```
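With `converter.py` now building the prefill and decode caches from `export_config.kvcache_layout`, the transposed layout can be requested at export time through `ExportConfig`. A hedged sketch, assuming `ExportConfig` is instantiated with keyword arguments matching its dataclass-style fields shown above:

```python
from ai_edge_torch.generative.layers import kv_cache as kv_utils
from ai_edge_torch.generative.utilities.export_config import ExportConfig

# Request the transposed K/V layout for the exported prefill/decode signatures;
# leaving kvcache_layout unset keeps KV_LAYOUT_DEFAULT.
export_config = ExportConfig(
    kvcache_layout=kv_utils.KV_LAYOUT_TRANSPOSED,
    decode_batch_size=1,
)
```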
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
```diff
@@ -301,3 +301,22 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
   )
   out = stablehlo.select(pred, self, src)
   return out
+
+
+# Schema:
+#   - aten::_to_copy(Tensor self, *, ScalarType? dtype=None,
+#       Layout? layout=None, Device? device=None, bool? pin_memory=None,
+#       bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+@lower(torch.ops.aten._to_copy.default)
+def _aten_to_copy(
+    lctx, x: ir.Value, dtype: torch.dtype | None = None, **kwargs
+):
+  if not dtype:
+    return x
+
+  return stablehlo.convert(
+      ir.RankedTensorType.get(
+          x.type.shape, utils.torch_dtype_to_ir_element_type(dtype)
+      ),
+      x,
+  )
```
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
```diff
@@ -74,7 +74,6 @@ lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit)
 lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
 lower_by_torch_xla2(torch.ops.aten._pdist_forward)
 lower_by_torch_xla2(torch.ops.aten._softmax)
-lower_by_torch_xla2(torch.ops.aten._to_copy)
 lower_by_torch_xla2(torch.ops.aten._unsafe_index)
 lower_by_torch_xla2(torch.ops.aten._unsafe_view)
 lower_by_torch_xla2(torch.ops.aten.acos)
```
ai_edge_torch/odml_torch/lowerings/utils.py CHANGED
```diff
@@ -37,6 +37,7 @@ def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
       torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
       torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
       torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+      torch.bfloat16: ir.BF16Type.get,
   }[dtype]
   return ty_get()
 
```
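Taken together, the new `cast_inputs_bf16_to_f32_pass`, the `aten._to_copy` lowering above, and the `torch.bfloat16` entry in `torch_dtype_to_ir_element_type` target models whose signatures include bfloat16 tensors. A hedged sketch of such a conversion; the toy module is invented, and only the top-level `ai_edge_torch.convert` entry point is assumed:

```python
import ai_edge_torch
import torch


class ToyBf16Model(torch.nn.Module):
  """Invented module: takes a bfloat16 input and upcasts it internally."""

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return (x.to(torch.float32) * 2.0).sum(dim=-1)


# bfloat16 sample input; the bf16 -> f32 input cast pass is meant to cover this case.
sample_args = (torch.randn(1, 16).to(torch.bfloat16),)
edge_model = ai_edge_torch.convert(ToyBf16Model().eval(), sample_args)
```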
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/METADATA CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250408
+Version: 0.5.0.dev20250410
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -25,7 +25,9 @@ License-File: LICENSE
 Requires-Dist: numpy
 Requires-Dist: scipy
 Requires-Dist: safetensors
-Requires-Dist:
+Requires-Dist: multipledispatch
+Requires-Dist: transformers
+Requires-Dist: kagglehub
 Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: tf-nightly>=2.19.0.dev20250101
```
{ai_edge_torch_nightly-0.5.0.dev20250408.dist-info → ai_edge_torch_nightly-0.5.0.dev20250410.dist-info}/RECORD CHANGED
```diff
@@ -2,16 +2,17 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=dQvyVQmvNYF8n8HwlkY-9fdSo-n3_bdLO9EAZpJnC8s,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=6LtGzzqT2IXprfI_vPYKhE7IuN5XmPG0xy-v0UtZ9yk,1361
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
 ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
+ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
@@ -26,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=
+ai_edge_torch/_convert/test/test_convert.py,sha256=6vQa0UJn2L3qxR967_-vkfLrO7JdrLLBk4BfguOtHRI,17874
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -65,12 +66,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSd
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
 ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
-ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=
+ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=szssSBrIUYdNIoU7LHdAq7wCqgjaY6qbV8yvTgg796Q,2945
+ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=n6ZQfqNEHuOhY7Pu21bb8Eax8yn2Sx5osTKJKmhonXY,15659
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=5PEt0aWJ5wkUBvMoWFOJ-C48ZhG7uCVb8PCKQtZ8Fvw,6485
 ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=uRoLoBWzFtQz5wFZfPCxbkvZsgPAqSkUUsV3977GbYc,5184
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
-ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=
+ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=nEv0qQ0l6gSXKxP5mNwkd2lRGxpFfD4e7FNV3V76zhw,8915
 ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=A4uLUdqvU1NKo3seqZlWSS3fqYahnEKqNBQBJO6yXvE,1762
 ai_edge_torch/generative/examples/llama/llama.py,sha256=UKvMO85_5z1vEY5MVu6QBW_vpQYA8LWHbJI4Yx6BrCc,6592
@@ -153,17 +154,18 @@ ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDq
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
-ai_edge_torch/generative/layers/kv_cache.py,sha256=
+ai_edge_torch/generative/layers/kv_cache.py,sha256=9kkFpB9msgUDStFxEyQYYsavKPP4Dgqb_NFcd4hA4aU,8502
 ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
 ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=TcwiI1IHhcYUrTx0kpSPAJMxFfjFcDwAHHULfZm67U4,3785
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
-ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=
-ai_edge_torch/generative/layers/experimental/types.py,sha256=
+ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
+ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
+ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=8M6tC5kIUus-wbMEKDSMbCLnsobs6rgbujycsmhYa5g,2807
+ai_edge_torch/generative/layers/experimental/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -177,7 +179,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
-ai_edge_torch/generative/test/test_kv_cache.py,sha256=
+ai_edge_torch/generative/test/test_kv_cache.py,sha256=1sXN2RPntq0PP3IEy0NkvIbzQ0Y8JhPIwRSFwO9JLlE,5728
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
 ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
@@ -185,8 +187,8 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=-v2Vj7Qdd3Gy
 ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
 ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
-ai_edge_torch/generative/utilities/export_config.py,sha256
+ai_edge_torch/generative/utilities/converter.py,sha256=87Tzj-gLydx8_xnHxKlCbMmM1XHShstpKi8RH3xY7Xw,9757
+ai_edge_torch/generative/utilities/export_config.py,sha256=8-795nyd3M34LkGhgW7hwHlJyTc2Oz1iipHK8yBhdFs,1633
 ai_edge_torch/generative/utilities/loader.py,sha256=7p__m2JryWphGlYOuRxdoT4id4_tWJEVOV7y2X4H-Ak,13737
 ai_edge_torch/generative/utilities/model_builder.py,sha256=ZYX1TxpFdj573du2QCyHJlFjx4q1m12R74fp4Gwl92A,6343
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
@@ -203,7 +205,7 @@ ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrss
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
@@ -223,17 +225,17 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=4syWstepGiw3IKa8O7lciXywY7RFJ7OCWFMU1Lg3h-s,10777
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
 ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=JRGLXW8EQ1L-vdiVTkD1kb4AnTU05eRwZ7Ke010hZmg,11473
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
 ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
-ai_edge_torch/odml_torch/lowerings/utils.py,sha256
+ai_edge_torch/odml_torch/lowerings/utils.py,sha256=uJaFbbgvYMI4-VFpFcMpaObNfBQl6nV0x8Yo8LaSAOE,8974
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -243,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
-ai_edge_torch_nightly-0.5.0.
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/METADATA,sha256=8m6hxUmTT0arSNdEuo-mOyg1w9T3ekAvedvY2T6Opgw,2051
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/RECORD,,
```