ai-edge-torch-nightly 0.5.0.dev20250409__py3-none-any.whl → 0.5.0.dev20250411__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/conversion.py +1 -1
- ai_edge_torch/_convert/fx_passes/__init__.py +1 -0
- ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py +50 -0
- ai_edge_torch/_convert/test/test_convert.py +21 -0
- ai_edge_torch/generative/layers/experimental/attention.py +3 -33
- ai_edge_torch/generative/layers/experimental/kv_cache.py +2 -1
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/kv_cache.py +1 -1
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +124 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +19 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/utils.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250409.dist-info → ai_edge_torch_nightly-0.5.0.dev20250411.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.5.0.dev20250409.dist-info → ai_edge_torch_nightly-0.5.0.dev20250411.dist-info}/RECORD +20 -18
- /ai_edge_torch/generative/{layers/experimental → utilities}/types.py +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250409.dist-info → ai_edge_torch_nightly-0.5.0.dev20250411.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250409.dist-info → ai_edge_torch_nightly-0.5.0.dev20250411.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.5.0.dev20250409.dist-info → ai_edge_torch_nightly-0.5.0.dev20250411.dist-info}/top_level.txt +0 -0
@@ -40,8 +40,8 @@ def _run_convert_passes(
|
|
40
40
|
fx_passes.OptimizeLayoutTransposesPass(),
|
41
41
|
fx_passes.CanonicalizePass(),
|
42
42
|
fx_passes.BuildAtenCompositePass(),
|
43
|
-
fx_passes.CanonicalizePass(),
|
44
43
|
fx_passes.RemoveNonUserOutputsPass(),
|
44
|
+
fx_passes.CastInputsBf16ToF32Pass(),
|
45
45
|
fx_passes.CanonicalizePass(),
|
46
46
|
]
|
47
47
|
|
@@ -17,6 +17,7 @@ from typing import Sequence, Union
|
|
17
17
|
|
18
18
|
from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
|
19
19
|
from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
|
20
|
+
from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
|
20
21
|
from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
|
21
22
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
|
22
23
|
from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Pass to cast all inputs with torch.bfloat16 type to torch.float32."""
|
16
|
+
|
17
|
+
|
18
|
+
from ai_edge_torch import fx_infra
|
19
|
+
import torch
|
20
|
+
|
21
|
+
|
22
|
+
def cast_f32(x):
|
23
|
+
return x.to(torch.float32)
|
24
|
+
|
25
|
+
|
26
|
+
class CastInputsBf16ToF32Pass(fx_infra.ExportedProgramPassBase):
|
27
|
+
"""This pass casts all inputs with torch.bfloat16 type to torch.float32."""
|
28
|
+
|
29
|
+
def call(self, exported_program: torch.export.ExportedProgram):
|
30
|
+
modified = False
|
31
|
+
for node in exported_program.graph.nodes:
|
32
|
+
if (
|
33
|
+
node.op == "placeholder"
|
34
|
+
and node.meta.get("val").dtype == torch.bfloat16
|
35
|
+
):
|
36
|
+
if not node.users:
|
37
|
+
continue
|
38
|
+
|
39
|
+
modified = True
|
40
|
+
user = next(iter(node.users))
|
41
|
+
with exported_program.graph.inserting_before(user):
|
42
|
+
cast_node = exported_program.graph.call_function(
|
43
|
+
cast_f32,
|
44
|
+
(node,),
|
45
|
+
)
|
46
|
+
node.replace_all_uses_with(cast_node)
|
47
|
+
cast_node.replace_input_with(cast_node, node)
|
48
|
+
|
49
|
+
exported_program.graph_module.recompile()
|
50
|
+
return fx_infra.ExportedProgramPassResult(exported_program, modified)
|
@@ -553,6 +553,27 @@ class TestConvert(googletest.TestCase):
|
|
553
553
|
self.fail(f"PT2E conversion failed: {err}")
|
554
554
|
# pylint: enable=broad-except
|
555
555
|
|
556
|
+
def test_convert_model_with_bfloat16_inputs(self):
|
557
|
+
"""Test converting a simple model with torch.bfloat16 input.
|
558
|
+
|
559
|
+
bf16 inputs would remain in converted model signature but be casted to f32
|
560
|
+
right after the model inputs.
|
561
|
+
"""
|
562
|
+
|
563
|
+
class SampleModel(nn.Module):
|
564
|
+
|
565
|
+
def forward(self, x: torch.Tensor):
|
566
|
+
return (x + 1) * 1.2
|
567
|
+
|
568
|
+
model = SampleModel().eval()
|
569
|
+
args = (torch.randn(10, 10).to(torch.bfloat16),)
|
570
|
+
# pylint: disable=broad-except
|
571
|
+
try:
|
572
|
+
ai_edge_torch.convert(model, args)
|
573
|
+
except Exception as err:
|
574
|
+
self.fail(f"Conversion failed with bloat16 inputs: {err}")
|
575
|
+
# pylint: enable=broad-except
|
576
|
+
|
556
577
|
|
557
578
|
if __name__ == "__main__":
|
558
579
|
googletest.main()
|
@@ -24,8 +24,7 @@ from typing import Optional, Tuple, Union
|
|
24
24
|
from ai_edge_torch.generative.layers import builder
|
25
25
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
26
26
|
from ai_edge_torch.generative.layers import lora as lora_utils
|
27
|
-
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
|
28
|
-
from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
|
27
|
+
from ai_edge_torch.generative.layers import sdpa_with_kv_update
|
29
28
|
import ai_edge_torch.generative.layers.model_config as cfg
|
30
29
|
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
|
31
30
|
import torch
|
@@ -147,7 +146,6 @@ class CausalSelfAttention(nn.Module):
|
|
147
146
|
self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
|
148
147
|
self.config = config
|
149
148
|
self.enable_hlfb = enable_hlfb
|
150
|
-
self.sdpa_func = sdpa.scaled_dot_product_attention
|
151
149
|
|
152
150
|
def forward(
|
153
151
|
self,
|
@@ -221,36 +219,8 @@ class CausalSelfAttention(nn.Module):
|
|
221
219
|
cos, sin = rope
|
222
220
|
q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
|
223
221
|
|
224
|
-
|
225
|
-
|
226
|
-
g = n // self.config.num_query_groups
|
227
|
-
# btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
|
228
|
-
q = q.permute(0, 2, 1, 3).reshape(
|
229
|
-
1, b * self.config.num_query_groups, g * T, h
|
230
|
-
)
|
231
|
-
|
232
|
-
k = k.permute(0, 2, 1, 3).reshape(
|
233
|
-
1, -1, T, self.config.head_dim
|
234
|
-
) # 1, bk, s, h
|
235
|
-
v = v.permute(0, 2, 3, 1).reshape(
|
236
|
-
1, -1, self.config.head_dim, T
|
237
|
-
) # 1, bk, h, s
|
238
|
-
|
239
|
-
if kv_cache is not None:
|
240
|
-
kv_cache = kv_utils_experimental.update(kv_cache, input_pos, k, v)
|
241
|
-
k, v = kv_cache.k_cache, kv_cache.v_cache
|
242
|
-
|
243
|
-
sdpa_out = self.sdpa_func(
|
244
|
-
kv_cache,
|
245
|
-
q,
|
246
|
-
k,
|
247
|
-
v,
|
248
|
-
self.config.head_dim,
|
249
|
-
mask=mask,
|
250
|
-
softcap=self.config.logit_softcap,
|
251
|
-
) # 1, bk, gt, h
|
252
|
-
sdpa_out = (
|
253
|
-
sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
|
222
|
+
sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
|
223
|
+
q, k, v, kv_cache, input_pos, mask, self.config
|
254
224
|
)
|
255
225
|
|
256
226
|
# Compute the output projection.
|
@@ -44,7 +44,8 @@ def update(
|
|
44
44
|
assert (
|
45
45
|
cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
|
46
46
|
), "KV entry must have transposed layout."
|
47
|
-
|
47
|
+
update_kv_cache = _update_kv_impl_transposed
|
48
|
+
return update_kv_cache(cache, input_pos, k_slice, v_slice)
|
48
49
|
|
49
50
|
|
50
51
|
def _get_slice_indices(
|
@@ -20,7 +20,7 @@ from typing import Optional
|
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.custom_ops import bmm_4d as bmm_lib
|
22
22
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
23
|
-
from ai_edge_torch.generative.layers.experimental import types
|
23
|
+
from ai_edge_torch.generative.utilities import types
|
24
24
|
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
25
25
|
from multipledispatch import dispatch
|
26
26
|
import torch
|
@@ -82,7 +82,6 @@ def _sdpa(k_type, v_type, *args, **kwargs):
|
|
82
82
|
padded_logits = logits + mask
|
83
83
|
padded_logits = padded_logits.reshape(1, bk, gt, s)
|
84
84
|
probs = F.softmax(padded_logits, dim=-1).type_as(key)
|
85
|
-
|
86
85
|
encoded = bmm_lib.bmm_4d(probs, value)
|
87
86
|
|
88
87
|
return encoded # 1, bk, gt, h
|
@@ -20,7 +20,7 @@ from typing import Any, List, Tuple
|
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.custom_ops.dynamic_update_slice import dynamic_update_slice
|
22
22
|
from ai_edge_torch.generative.layers import model_config
|
23
|
-
from ai_edge_torch.generative.layers.experimental import types
|
23
|
+
from ai_edge_torch.generative.utilities import types
|
24
24
|
import torch
|
25
25
|
import torch.utils._pytree as pytree
|
26
26
|
|
@@ -0,0 +1,124 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
# Common utility functions for data loading etc.
|
16
|
+
from dataclasses import dataclass
|
17
|
+
from typing import Tuple
|
18
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
19
|
+
from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
|
20
|
+
from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
|
21
|
+
from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
|
22
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
23
|
+
from ai_edge_torch.generative.utilities import types
|
24
|
+
from multipledispatch import dispatch
|
25
|
+
import torch
|
26
|
+
|
27
|
+
|
28
|
+
def sdpa_with_kv_update(
|
29
|
+
query: torch.Tensor,
|
30
|
+
key: torch.Tensor,
|
31
|
+
value: torch.Tensor,
|
32
|
+
kv: kv_utils.KVCacheEntry,
|
33
|
+
input_pos: torch.Tensor,
|
34
|
+
mask: torch.Tensor,
|
35
|
+
config: cfg.AttentionConfig,
|
36
|
+
) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
|
37
|
+
return sdpa_with_kv_update_impl(
|
38
|
+
kv.kv_layout[0](), # key layout
|
39
|
+
kv.kv_layout[1](), # value layout
|
40
|
+
query=query,
|
41
|
+
key=key,
|
42
|
+
value=value,
|
43
|
+
kv=kv,
|
44
|
+
input_pos=input_pos,
|
45
|
+
mask=mask,
|
46
|
+
config=config,
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
@dispatch(types.BNTH, types.BNHT)
|
51
|
+
def sdpa_with_kv_update_impl(
|
52
|
+
k_type, v_type, *args, **kwargs
|
53
|
+
) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
|
54
|
+
query = kwargs["query"]
|
55
|
+
key = kwargs["key"]
|
56
|
+
value = kwargs["value"]
|
57
|
+
kv = kwargs["kv"]
|
58
|
+
input_pos = kwargs["input_pos"]
|
59
|
+
mask = kwargs["mask"]
|
60
|
+
config = kwargs["config"]
|
61
|
+
|
62
|
+
# Transpose k/v to specific layout for GPU implementation.
|
63
|
+
b, seq_len, n, h = query.shape
|
64
|
+
g = n // config.num_query_groups
|
65
|
+
# btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
|
66
|
+
query = query.permute(0, 2, 1, 3).reshape(
|
67
|
+
1, b * config.num_query_groups, g * seq_len, h
|
68
|
+
)
|
69
|
+
|
70
|
+
key = key.permute(0, 2, 1, 3).reshape(
|
71
|
+
1, -1, seq_len, config.head_dim
|
72
|
+
) # 1, bk, s, h
|
73
|
+
value = value.permute(0, 2, 3, 1).reshape(
|
74
|
+
1, -1, config.head_dim, seq_len
|
75
|
+
) # 1, bk, h, s
|
76
|
+
|
77
|
+
if kv is not None:
|
78
|
+
kv = kv_utils_experimental.update(kv, input_pos, key, value)
|
79
|
+
key, value = kv.k_cache, kv.v_cache
|
80
|
+
|
81
|
+
sdpa_out = sdpa.scaled_dot_product_attention(
|
82
|
+
kv,
|
83
|
+
query,
|
84
|
+
key,
|
85
|
+
value,
|
86
|
+
config.head_dim,
|
87
|
+
mask=mask,
|
88
|
+
softcap=config.logit_softcap,
|
89
|
+
) # 1, bk, gt, h
|
90
|
+
sdpa_out = (
|
91
|
+
sdpa_out.reshape(b, -1, seq_len, h)
|
92
|
+
.permute(0, 2, 1, 3)
|
93
|
+
.reshape(b, seq_len, -1)
|
94
|
+
)
|
95
|
+
return sdpa_out, kv
|
96
|
+
|
97
|
+
|
98
|
+
@dispatch(object, object)
|
99
|
+
def sdpa_with_kv_update_impl(
|
100
|
+
k_type, v_type, *args, **kwargs
|
101
|
+
) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
|
102
|
+
query = kwargs["query"]
|
103
|
+
key = kwargs["key"]
|
104
|
+
value = kwargs["value"]
|
105
|
+
kv = kwargs["kv"]
|
106
|
+
input_pos = kwargs["input_pos"]
|
107
|
+
mask = kwargs["mask"]
|
108
|
+
config = kwargs["config"]
|
109
|
+
|
110
|
+
b, seq_len, _, _ = query.shape
|
111
|
+
if kv is not None:
|
112
|
+
kv = kv_utils.update(kv, input_pos, key, value)
|
113
|
+
key, value = kv.k_cache, kv.v_cache
|
114
|
+
|
115
|
+
sdpa_out = sdpa_default.scaled_dot_product_attention(
|
116
|
+
query,
|
117
|
+
key,
|
118
|
+
value,
|
119
|
+
config.head_dim,
|
120
|
+
mask=mask,
|
121
|
+
softcap=config.logit_softcap,
|
122
|
+
)
|
123
|
+
sdpa_out = sdpa_out.reshape(b, seq_len, -1)
|
124
|
+
return sdpa_out, kv
|
@@ -301,3 +301,22 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
301
301
|
)
|
302
302
|
out = stablehlo.select(pred, self, src)
|
303
303
|
return out
|
304
|
+
|
305
|
+
|
306
|
+
# Schema:
|
307
|
+
# - aten::_to_copy(Tensor self, *, ScalarType? dtype=None,
|
308
|
+
# Layout? layout=None, Device? device=None, bool? pin_memory=None,
|
309
|
+
# bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
|
310
|
+
@lower(torch.ops.aten._to_copy.default)
|
311
|
+
def _aten_to_copy(
|
312
|
+
lctx, x: ir.Value, dtype: torch.dtype | None = None, **kwargs
|
313
|
+
):
|
314
|
+
if not dtype:
|
315
|
+
return x
|
316
|
+
|
317
|
+
return stablehlo.convert(
|
318
|
+
ir.RankedTensorType.get(
|
319
|
+
x.type.shape, utils.torch_dtype_to_ir_element_type(dtype)
|
320
|
+
),
|
321
|
+
x,
|
322
|
+
)
|
@@ -74,7 +74,6 @@ lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit)
|
|
74
74
|
lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
|
75
75
|
lower_by_torch_xla2(torch.ops.aten._pdist_forward)
|
76
76
|
lower_by_torch_xla2(torch.ops.aten._softmax)
|
77
|
-
lower_by_torch_xla2(torch.ops.aten._to_copy)
|
78
77
|
lower_by_torch_xla2(torch.ops.aten._unsafe_index)
|
79
78
|
lower_by_torch_xla2(torch.ops.aten._unsafe_view)
|
80
79
|
lower_by_torch_xla2(torch.ops.aten.acos)
|
@@ -37,6 +37,7 @@ def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
|
|
37
37
|
torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
|
38
38
|
torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
|
39
39
|
torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
|
40
|
+
torch.bfloat16: ir.BF16Type.get,
|
40
41
|
}[dtype]
|
41
42
|
return ty_get()
|
42
43
|
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.5.0.dev20250409
|
3
|
+
Version: 0.5.0.dev20250411
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,16 +2,17 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=eSFSbkcpm9gvoLSoHH1AACS_xhGjYFxHgSKTrIlbpsQ,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
11
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
12
|
-
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=
|
12
|
+
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=6LtGzzqT2IXprfI_vPYKhE7IuN5XmPG0xy-v0UtZ9yk,1361
|
13
13
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
|
14
14
|
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
|
15
|
+
ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
|
15
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
|
16
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
17
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
|
@@ -26,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
26
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
27
28
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
29
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
30
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=6vQa0UJn2L3qxR967_-vkfLrO7JdrLLBk4BfguOtHRI,17874
|
30
31
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
|
31
32
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
33
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
@@ -153,17 +154,17 @@ ai_edge_torch/generative/layers/attention.py,sha256=wLZ1jgUlcODBWgK3hnnhclHuuQDq
|
|
153
154
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
154
155
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
155
156
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
156
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
157
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=WNH_Ab29eXKXs8HAm3Wmdv_LBzO6PQW5d34Eo6Yzgd0,8492
|
157
158
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
158
159
|
ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQoJ-K8D10cKR9229SAw,8355
|
159
160
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
160
161
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
161
162
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
163
|
+
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=oo9h7pi0GcuylRgp2yUuvUJCrhj03aoWt_fP7EDP4LM,3775
|
162
164
|
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
163
|
-
ai_edge_torch/generative/layers/experimental/attention.py,sha256=
|
164
|
-
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=
|
165
|
-
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256
|
166
|
-
ai_edge_torch/generative/layers/experimental/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
|
165
|
+
ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
|
166
|
+
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
|
167
|
+
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=YFcIGOkaNb-vvQKjI-G9-bC2Z1W0O_qRyIZPlsLl72U,2797
|
167
168
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
168
169
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
169
170
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
@@ -193,6 +194,7 @@ ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_
|
|
193
194
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
194
195
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
195
196
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
197
|
+
ai_edge_torch/generative/utilities/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
|
196
198
|
ai_edge_torch/generative/utilities/verifier.py,sha256=gsBGv-WgeP73Bmao29CiTIQOC1YZ43IUJcGzytpcZyM,12095
|
197
199
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
198
200
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
|
@@ -203,7 +205,7 @@ ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrss
|
|
203
205
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
204
206
|
ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
|
205
207
|
ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
|
206
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
208
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
|
207
209
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
208
210
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
209
211
|
ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
|
@@ -223,17 +225,17 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
|
|
223
225
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
224
226
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
225
227
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
226
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
228
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=4syWstepGiw3IKa8O7lciXywY7RFJ7OCWFMU1Lg3h-s,10777
|
227
229
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
228
230
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
229
231
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
|
230
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
232
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=JRGLXW8EQ1L-vdiVTkD1kb4AnTU05eRwZ7Ke010hZmg,11473
|
231
233
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
232
234
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
233
235
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
234
236
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
235
237
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
236
|
-
ai_edge_torch/odml_torch/lowerings/utils.py,sha256
|
238
|
+
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=uJaFbbgvYMI4-VFpFcMpaObNfBQl6nV0x8Yo8LaSAOE,8974
|
237
239
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
238
240
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
239
241
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
@@ -243,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
243
245
|
ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
|
244
246
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
245
247
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
246
|
-
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
247
|
-
ai_edge_torch_nightly-0.5.0.
|
248
|
-
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
249
|
-
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
250
|
-
ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/RECORD,,
|
248
|
+
ai_edge_torch_nightly-0.5.0.dev20250411.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
249
|
+
ai_edge_torch_nightly-0.5.0.dev20250411.dist-info/METADATA,sha256=NcEM89nSJwDWPZTuq137MbNXbp3zMHbtDZpR0_Q9z08,2051
|
250
|
+
ai_edge_torch_nightly-0.5.0.dev20250411.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
251
|
+
ai_edge_torch_nightly-0.5.0.dev20250411.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
252
|
+
ai_edge_torch_nightly-0.5.0.dev20250411.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|