ai-edge-torch-nightly 0.5.0.dev20250409__py3-none-any.whl → 0.5.0.dev20250410__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/_convert/conversion.py CHANGED
@@ -40,8 +40,8 @@ def _run_convert_passes(
       fx_passes.OptimizeLayoutTransposesPass(),
       fx_passes.CanonicalizePass(),
       fx_passes.BuildAtenCompositePass(),
-      fx_passes.CanonicalizePass(),
       fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CastInputsBf16ToF32Pass(),
       fx_passes.CanonicalizePass(),
   ]
 
ai_edge_torch/_convert/fx_passes/__init__.py CHANGED
@@ -17,6 +17,7 @@ from typing import Sequence, Union
 
 from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass
 from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass
+from ai_edge_torch._convert.fx_passes.cast_inputs_bf16_to_f32_pass import CastInputsBf16ToF32Pass
 from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
 from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py ADDED
@@ -0,0 +1,50 @@
+# Copyright 2025 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pass to cast all inputs with torch.bfloat16 type to torch.float32."""
+
+
+from ai_edge_torch import fx_infra
+import torch
+
+
+def cast_f32(x):
+  return x.to(torch.float32)
+
+
+class CastInputsBf16ToF32Pass(fx_infra.ExportedProgramPassBase):
+  """This pass casts all inputs with torch.bfloat16 type to torch.float32."""
+
+  def call(self, exported_program: torch.export.ExportedProgram):
+    modified = False
+    for node in exported_program.graph.nodes:
+      if (
+          node.op == "placeholder"
+          and node.meta.get("val").dtype == torch.bfloat16
+      ):
+        if not node.users:
+          continue
+
+        modified = True
+        user = next(iter(node.users))
+        with exported_program.graph.inserting_before(user):
+          cast_node = exported_program.graph.call_function(
+              cast_f32,
+              (node,),
+          )
+          node.replace_all_uses_with(cast_node)
+          cast_node.replace_input_with(cast_node, node)
+
+    exported_program.graph_module.recompile()
+    return fx_infra.ExportedProgramPassResult(exported_program, modified)
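
The new pass rewires every used bf16 placeholder so that a cast-to-f32 node sits between the input and its first user, then re-points all users at the cast. As a minimal, self-contained sketch of the same torch.fx rewrite on a toy traced module (ToyModel below is illustrative, not part of the package):

# Minimal illustration of the rewrite CastInputsBf16ToF32Pass performs:
# insert a cast node before the placeholder's first user, then redirect users.
import torch
import torch.fx


class ToyModel(torch.nn.Module):  # hypothetical stand-in model
  def forward(self, x):
    return x * 2.0


gm = torch.fx.symbolic_trace(ToyModel())
for node in gm.graph.nodes:
  if node.op == "placeholder" and node.users:
    user = next(iter(node.users))
    with gm.graph.inserting_before(user):
      cast = gm.graph.call_method("to", (node,), {"dtype": torch.float32})
    node.replace_all_uses_with(cast)
    # replace_all_uses_with also rewired the cast's own input to point at
    # itself, so restore the original placeholder as its input (the pass
    # does the same fixup via replace_input_with).
    cast.replace_input_with(cast, node)
gm.recompile()
print(gm.code)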
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -553,6 +553,27 @@ class TestConvert(googletest.TestCase):
       self.fail(f"PT2E conversion failed: {err}")
     # pylint: enable=broad-except
 
+  def test_convert_model_with_bfloat16_inputs(self):
+    """Tests converting a simple model with a torch.bfloat16 input.
+
+    bf16 inputs remain in the converted model signature but are cast to f32
+    right after the model inputs.
+    """
+
+    class SampleModel(nn.Module):
+
+      def forward(self, x: torch.Tensor):
+        return (x + 1) * 1.2
+
+    model = SampleModel().eval()
+    args = (torch.randn(10, 10).to(torch.bfloat16),)
+    # pylint: disable=broad-except
+    try:
+      ai_edge_torch.convert(model, args)
+    except Exception as err:
+      self.fail(f"Conversion failed with bfloat16 inputs: {err}")
+    # pylint: enable=broad-except
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/layers/experimental/attention.py CHANGED
@@ -24,8 +24,7 @@ from typing import Optional, Tuple, Union
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.layers import lora as lora_utils
-from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
-from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers import sdpa_with_kv_update
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import torch
@@ -147,7 +146,6 @@ class CausalSelfAttention(nn.Module):
     self.key_norm = builder.build_norm(config.head_dim, config.key_norm_config)
     self.config = config
     self.enable_hlfb = enable_hlfb
-    self.sdpa_func = sdpa.scaled_dot_product_attention
 
   def forward(
       self,
@@ -221,36 +219,8 @@ class CausalSelfAttention(nn.Module):
       cos, sin = rope
       q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
 
-    # Transpose k/v to specific layout for GPU implementation.
-    b, _, n, h = q.shape
-    g = n // self.config.num_query_groups
-    # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
-    q = q.permute(0, 2, 1, 3).reshape(
-        1, b * self.config.num_query_groups, g * T, h
-    )
-
-    k = k.permute(0, 2, 1, 3).reshape(
-        1, -1, T, self.config.head_dim
-    )  # 1, bk, s, h
-    v = v.permute(0, 2, 3, 1).reshape(
-        1, -1, self.config.head_dim, T
-    )  # 1, bk, h, s
-
-    if kv_cache is not None:
-      kv_cache = kv_utils_experimental.update(kv_cache, input_pos, k, v)
-      k, v = kv_cache.k_cache, kv_cache.v_cache
-
-    sdpa_out = self.sdpa_func(
-        kv_cache,
-        q,
-        k,
-        v,
-        self.config.head_dim,
-        mask=mask,
-        softcap=self.config.logit_softcap,
-    )  # 1, bk, gt, h
-    sdpa_out = (
-        sdpa_out.reshape(B, -1, T, h).permute(0, 2, 1, 3).reshape(B, T, -1)
+    sdpa_out, kv_cache = sdpa_with_kv_update.sdpa_with_kv_update(
+        q, k, v, kv_cache, input_pos, mask, self.config
     )
 
     # Compute the output projection.
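
The block removed here folds batch and KV heads into a single axis before attention; that bookkeeping now lives in the new sdpa_with_kv_update module further down in this diff. A quick standalone shape check of the `btnh -> bnth -> 1(bk)(gt)h` fold, with arbitrary dimension sizes:

import torch

# Arbitrary sizes: batch, tokens, query heads, head_dim, and KV heads (k).
b, t, n, h = 2, 16, 8, 64
num_query_groups = 4            # k
g = n // num_query_groups       # query heads per KV head

q = torch.randn(b, t, n, h)                   # btnh
q_folded = q.permute(0, 2, 1, 3).reshape(     # btnh -> bnth
    1, b * num_query_groups, g * t, h         # -> 1, (b*k), (g*t), h
)
print(q_folded.shape)  # torch.Size([1, 8, 32, 64])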
ai_edge_torch/generative/layers/experimental/kv_cache.py CHANGED
@@ -44,7 +44,8 @@ def update(
   assert (
       cache.kv_layout == kv_utils.KV_LAYOUT_TRANSPOSED
   ), "KV entry must have transposed layout."
-  return _update_kv_impl_transposed(cache, input_pos, k_slice, v_slice)
+  update_kv_cache = _update_kv_impl_transposed
+  return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
 
 def _get_slice_indices(
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py CHANGED
@@ -82,7 +82,6 @@ def _sdpa(k_type, v_type, *args, **kwargs):
   padded_logits = logits + mask
   padded_logits = padded_logits.reshape(1, bk, gt, s)
   probs = F.softmax(padded_logits, dim=-1).type_as(key)
-
   encoded = bmm_lib.bmm_4d(probs, value)
 
   return encoded  # 1, bk, gt, h
ai_edge_torch/generative/layers/sdpa_with_kv_update.py ADDED
@@ -0,0 +1,124 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+from typing import Tuple
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa_default
+from ai_edge_torch.generative.layers.experimental import kv_cache as kv_utils_experimental
+from ai_edge_torch.generative.layers.experimental import scaled_dot_product_attention as sdpa
+from ai_edge_torch.generative.layers.experimental import types
+import ai_edge_torch.generative.layers.model_config as cfg
+from multipledispatch import dispatch
+import torch
+
+
+def sdpa_with_kv_update(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv: kv_utils.KVCacheEntry,
+    input_pos: torch.Tensor,
+    mask: torch.Tensor,
+    config: cfg.AttentionConfig,
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  return sdpa_with_kv_update_impl(
+      kv.kv_layout[0](),  # key layout
+      kv.kv_layout[1](),  # value layout
+      query=query,
+      key=key,
+      value=value,
+      kv=kv,
+      input_pos=input_pos,
+      mask=mask,
+      config=config,
+  )
+
+
+@dispatch(types.BNTH, types.BNHT)
+def sdpa_with_kv_update_impl(
+    k_type, v_type, *args, **kwargs
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  query = kwargs["query"]
+  key = kwargs["key"]
+  value = kwargs["value"]
+  kv = kwargs["kv"]
+  input_pos = kwargs["input_pos"]
+  mask = kwargs["mask"]
+  config = kwargs["config"]
+
+  # Transpose k/v to specific layout for GPU implementation.
+  b, seq_len, n, h = query.shape
+  g = n // config.num_query_groups
+  # btnh -> bnth -> b(kg)th -> 1(bk)(gt)h
+  query = query.permute(0, 2, 1, 3).reshape(
+      1, b * config.num_query_groups, g * seq_len, h
+  )
+
+  key = key.permute(0, 2, 1, 3).reshape(
+      1, -1, seq_len, config.head_dim
+  )  # 1, bk, s, h
+  value = value.permute(0, 2, 3, 1).reshape(
+      1, -1, config.head_dim, seq_len
+  )  # 1, bk, h, s
+
+  if kv is not None:
+    kv = kv_utils_experimental.update(kv, input_pos, key, value)
+    key, value = kv.k_cache, kv.v_cache
+
+  sdpa_out = sdpa.scaled_dot_product_attention(
+      kv,
+      query,
+      key,
+      value,
+      config.head_dim,
+      mask=mask,
+      softcap=config.logit_softcap,
+  )  # 1, bk, gt, h
+  sdpa_out = (
+      sdpa_out.reshape(b, -1, seq_len, h)
+      .permute(0, 2, 1, 3)
+      .reshape(b, seq_len, -1)
+  )
+  return sdpa_out, kv
+
+
+@dispatch(object, object)
+def sdpa_with_kv_update_impl(
+    k_type, v_type, *args, **kwargs
+) -> Tuple[torch.Tensor, kv_utils.KVCacheEntry]:
+  query = kwargs["query"]
+  key = kwargs["key"]
+  value = kwargs["value"]
+  kv = kwargs["kv"]
+  input_pos = kwargs["input_pos"]
+  mask = kwargs["mask"]
+  config = kwargs["config"]
+
+  b, seq_len, _, _ = query.shape
+  if kv is not None:
+    kv = kv_utils.update(kv, input_pos, key, value)
+    key, value = kv.k_cache, kv.v_cache
+
+  sdpa_out = sdpa_default.scaled_dot_product_attention(
+      query,
+      key,
+      value,
+      config.head_dim,
+      mask=mask,
+      softcap=config.logit_softcap,
+  )
+  sdpa_out = sdpa_out.reshape(b, seq_len, -1)
+  return sdpa_out, kv
ai_edge_torch/lowertools/odml_torch_utils.py CHANGED
@@ -52,6 +52,7 @@ def torch_dtype_to_tf(dtype):
       torch.int32: tf.int32,
       torch.int16: tf.int16,
       torch.bool: tf.bool,
+      torch.bfloat16: tf.bfloat16,
   }.get(dtype)
 
 
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -301,3 +301,22 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
   )
   out = stablehlo.select(pred, self, src)
   return out
+
+
+# Schema:
+#   - aten::_to_copy(Tensor self, *, ScalarType? dtype=None,
+#       Layout? layout=None, Device? device=None, bool? pin_memory=None,
+#       bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
+@lower(torch.ops.aten._to_copy.default)
+def _aten_to_copy(
+    lctx, x: ir.Value, dtype: torch.dtype | None = None, **kwargs
+):
+  if not dtype:
+    return x
+
+  return stablehlo.convert(
+      ir.RankedTensorType.get(
+          x.type.shape, utils.torch_dtype_to_ir_element_type(dtype)
+      ),
+      x,
+  )
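
The new lowering handles dtype-only copies: when a dtype is given it emits a stablehlo.convert to the corresponding IR element type (the torch.bfloat16 mapping is added further down), otherwise it returns the value unchanged. For context, dtype casts such as `x.to(torch.float32)` typically surface as `aten._to_copy` nodes in an exported graph; a hedged sketch of inspecting that (the toy module is illustrative only):

import torch


class CastModel(torch.nn.Module):  # hypothetical example
  def forward(self, x):
    return x.to(torch.float32) + 1.0


ep = torch.export.export(
    CastModel().eval(), (torch.randn(4, dtype=torch.bfloat16),)
)
# The cast usually appears as an aten._to_copy node carrying dtype=torch.float32,
# which a lowering like the one above maps onto a stablehlo.convert.
print(ep.graph)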
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -74,7 +74,6 @@ lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit)
 lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
 lower_by_torch_xla2(torch.ops.aten._pdist_forward)
 lower_by_torch_xla2(torch.ops.aten._softmax)
-lower_by_torch_xla2(torch.ops.aten._to_copy)
 lower_by_torch_xla2(torch.ops.aten._unsafe_index)
 lower_by_torch_xla2(torch.ops.aten._unsafe_view)
 lower_by_torch_xla2(torch.ops.aten.acos)
ai_edge_torch/odml_torch/lowerings/utils.py CHANGED
@@ -37,6 +37,7 @@ def torch_dtype_to_ir_element_type(dtype) -> ir.Type:
       torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
       torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
       torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
+      torch.bfloat16: ir.BF16Type.get,
   }[dtype]
   return ty_get()
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.5.0.dev20250409"
+__version__ = "0.5.0.dev20250410"
ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.5.0.dev20250409
+Version: 0.5.0.dev20250410
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/RECORD CHANGED
@@ -2,16 +2,17 @@ ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,120
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=DEYqmCDZNmwuMxnxrFvcTEaDp6Z_BVHJaZMYjVQ2ijU,706
+ai_edge_torch/version.py,sha256=dQvyVQmvNYF8n8HwlkY-9fdSo-n3_bdLO9EAZpJnC8s,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=gpXQnifODU-mWxkUZw_3ov1lEYBw1SPVIcqj5k7pTGo,5550
+ai_edge_torch/_convert/conversion.py,sha256=GPDsXhfECjDzOut4vh_d9qWcyfpxobFMBTsC7MyJbM0,5557
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
-ai_edge_torch/_convert/fx_passes/__init__.py,sha256=dG4WIICk0FqCH9euvbYHHsybRN7B1cYcuxN_OYxmjWo,1263
+ai_edge_torch/_convert/fx_passes/__init__.py,sha256=6LtGzzqT2IXprfI_vPYKhE7IuN5XmPG0xy-v0UtZ9yk,1361
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
 ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
+ai_edge_torch/_convert/fx_passes/cast_inputs_bf16_to_f32_pass.py,sha256=90YxLVAAkiA3qKr4Um__JmPeC1bTeA2PxBCj0GETq1Q,1748
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
@@ -26,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=o6tuJkD-ESaQxLxJpN104qpchm3LCtPmHinzQxe6PSg,17226
+ai_edge_torch/_convert/test/test_convert.py,sha256=6vQa0UJn2L3qxR967_-vkfLrO7JdrLLBk4BfguOtHRI,17874
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -159,10 +160,11 @@ ai_edge_torch/generative/layers/model_config.py,sha256=nLXvTkDAIHJQ0PTaWODF8oxJQ
 ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
+ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=TcwiI1IHhcYUrTx0kpSPAJMxFfjFcDwAHHULfZm67U4,3785
 ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
-ai_edge_torch/generative/layers/experimental/attention.py,sha256=oW8cxv0pXcesnyGz6bXacRmlvHPfKNnJnls_Qb4L_aQ,8968
-ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=PlgL2bNNKasu3wFr3Iu9wbATWluWZt3_s4tzglJu2tM,2942
-ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=-ztTIgdec5gXkOVe6FXk3PMeS2HoL6-mBfDBdjQIcLQ,2808
+ai_edge_torch/generative/layers/experimental/attention.py,sha256=XYbo1KlmiMEuwArye0Ul86jEsdxLr1RG-usRpidZiT8,8001
+ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=zgpFVftOfllvjh9-UEBSvUbm152SnQETn29rUMMMvAM,2978
+ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=8M6tC5kIUus-wbMEKDSMbCLnsobs6rgbujycsmhYa5g,2807
 ai_edge_torch/generative/layers/experimental/types.py,sha256=gZI9hIPB3XAo4oecKIIoVDfiyibLaSNFhecPFx4VDTM,2913
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
@@ -203,7 +205,7 @@ ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrss
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
 ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
@@ -223,17 +225,17 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNi
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=Jq8_yAxC7ilzd6tOaRyBsOUEeenFF_EAC5haacZT4Pg,10247
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=4syWstepGiw3IKa8O7lciXywY7RFJ7OCWFMU1Lg3h-s,10777
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
 ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=h-YHW7xmvt9dpea-7Zj82HW7h5TKzW6GBEE13dIJQ40,11518
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=JRGLXW8EQ1L-vdiVTkD1kb4AnTU05eRwZ7Ke010hZmg,11473
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
 ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
 ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
-ai_edge_torch/odml_torch/lowerings/utils.py,sha256=-TzK1igPgR38oZkU1iPh-DZhlKVwuBtGWVC-y81PXzY,8935
+ai_edge_torch/odml_torch/lowerings/utils.py,sha256=uJaFbbgvYMI4-VFpFcMpaObNfBQl6nV0x8Yo8LaSAOE,8974
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
 ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
@@ -243,8 +245,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=dguMa-aEi-WDPnmGBUs2IPdEmt2IVmHOELH19uiJ1uU,3014
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/METADATA,sha256=kZwo6E79HLuM7_4E-Yw9erTzOnAAzio3Vy45hXNiC48,2051
-ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.5.0.dev20250409.dist-info/RECORD,,
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/METADATA,sha256=8m6hxUmTT0arSNdEuo-mOyg1w9T3ekAvedvY2T6Opgw,2051
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.5.0.dev20250410.dist-info/RECORD,,