ai-edge-torch-nightly 0.6.0.dev20250821__py3-none-any.whl → 0.6.0.dev20250823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/layers/builder.py +6 -1
- ai_edge_torch/generative/layers/normalization.py +6 -0
- ai_edge_torch/odml_torch/export.py +4 -2
- ai_edge_torch/odml_torch/export_utils.py +8 -2
- ai_edge_torch/odml_torch/lowerings/context.py +0 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250821.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250821.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.6.0.dev20250821.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250821.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250821.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/top_level.txt +0 -0
@@ -54,7 +54,11 @@ class SwiGLU(nn.Module):
|
|
54
54
|
return F.silu(x) * y
|
55
55
|
|
56
56
|
|
57
|
-
def build_norm(
|
57
|
+
def build_norm(
|
58
|
+
dim: int,
|
59
|
+
config: cfg.NormalizationConfig,
|
60
|
+
init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
|
61
|
+
):
|
58
62
|
"""Builder function for normalizers.
|
59
63
|
|
60
64
|
Args:
|
@@ -77,6 +81,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
|
|
77
81
|
with_scale=config.with_scale,
|
78
82
|
scale_shift=config.scale_shift,
|
79
83
|
enable_hlfb=config.enable_hlfb,
|
84
|
+
init_fn=init_fn,
|
80
85
|
)
|
81
86
|
elif config.type == cfg.NormalizationType.LAYER_NORM:
|
82
87
|
return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
# Common normalization layers.
|
16
16
|
|
17
|
+
from typing import Callable
|
18
|
+
|
17
19
|
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
18
20
|
import torch
|
19
21
|
from torch import nn
|
@@ -31,6 +33,7 @@ class RMSNorm(torch.nn.Module):
|
|
31
33
|
with_scale: bool = False,
|
32
34
|
scale_shift: float = 1.0,
|
33
35
|
enable_hlfb: bool = False,
|
36
|
+
init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
|
34
37
|
):
|
35
38
|
"""Initialize the RMSNorm layer.
|
36
39
|
|
@@ -42,12 +45,15 @@ class RMSNorm(torch.nn.Module):
|
|
42
45
|
with_scale (bool): Whether or not to use a scale parameter.
|
43
46
|
scale_shift (float): The shift to apply to the scale parameter.
|
44
47
|
enable_hlfb (bool): use HLFB in the op.
|
48
|
+
init_fn: The initialization function to use for the parameters. This is
|
49
|
+
used to initialize the scale parameter.
|
45
50
|
"""
|
46
51
|
super().__init__()
|
47
52
|
self.dim = dim
|
48
53
|
self.enable_hlfb = enable_hlfb
|
49
54
|
self.eps = eps
|
50
55
|
self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
|
56
|
+
init_fn(self.weight)
|
51
57
|
self.zero_centered_gamma = zero_centered_gamma
|
52
58
|
self.with_scale = with_scale
|
53
59
|
if with_scale:
|
@@ -124,9 +124,11 @@ class LoweringInterpreter(torch.fx.Interpreter):
|
|
124
124
|
def run_node(self, node: torch.fx.Node):
|
125
125
|
loc = self._build_loc(node)
|
126
126
|
with loc:
|
127
|
-
self.lctx =
|
127
|
+
self.lctx.ir_location = loc
|
128
|
+
self.lctx.node = node
|
128
129
|
res = super().run_node(node)
|
129
|
-
self.lctx
|
130
|
+
self.lctx.ir_location = None
|
131
|
+
self.lctx.node = None
|
130
132
|
return res
|
131
133
|
|
132
134
|
def call_function(self, target, args, kwargs):
|
@@ -63,7 +63,10 @@ def inline(
|
|
63
63
|
while True:
|
64
64
|
is_changed = False
|
65
65
|
for op in block.operations:
|
66
|
-
if
|
66
|
+
if (
|
67
|
+
not hasattr(op, "OPERATION_NAME")
|
68
|
+
or op.OPERATION_NAME != func.CallOp.OPERATION_NAME
|
69
|
+
):
|
67
70
|
continue
|
68
71
|
|
69
72
|
call_op = cast(func.CallOp, op)
|
@@ -96,7 +99,10 @@ def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
|
|
96
99
|
|
97
100
|
for op in list(func_op.entry_block.operations):
|
98
101
|
cloned_operands = [value_mapping[val] for val in op.operands]
|
99
|
-
if
|
102
|
+
if (
|
103
|
+
hasattr(op, "OPERATION_NAME")
|
104
|
+
and op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME
|
105
|
+
):
|
100
106
|
return cloned_operands
|
101
107
|
|
102
108
|
cloned = cast(ir.Operation, op.operation.clone())
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.6.0.
|
3
|
+
Version: 0.6.0.dev20250823
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=JyfBN0yvWEvYs121XsXOTiW8p-4px3FWXcaDmfbzVlY,806
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -173,7 +173,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6i
|
|
173
173
|
ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
|
174
174
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
|
175
175
|
ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
|
176
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
176
|
+
ai_edge_torch/generative/layers/builder.py,sha256=UiLOyvwd-bc0n5XcbVxi6JCn_qiKSC6zrDKSZT_TSDA,5030
|
177
177
|
ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
|
178
178
|
ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
|
179
179
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
|
@@ -181,7 +181,7 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1
|
|
181
181
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=A0IFXZ1HD2ZHOWRLfsDO4almgE0KQfjyBOdBFZIGnAs,10893
|
182
182
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
183
183
|
ai_edge_torch/generative/layers/model_config.py,sha256=HP-vu1UmAiTmdLlTyZGDUF3le0gji8a61mLCy966NZw,10261
|
184
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
184
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=WAhPcLbcC3SCEa6oIgdsojvN306_S8d90WyMQ7ZVP6I,7269
|
185
185
|
ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
|
186
186
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
187
187
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
|
@@ -235,8 +235,8 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=Oavs0ENKVnIeB-WidXvokTPqNlFf
|
|
235
235
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
236
236
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
237
237
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
238
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
239
|
-
ai_edge_torch/odml_torch/export_utils.py,sha256=
|
238
|
+
ai_edge_torch/odml_torch/export.py,sha256=WPsATMFsc62aETExBsb1v6Ec1fL3r0CEvxwE7r2C_CA,14931
|
239
|
+
ai_edge_torch/odml_torch/export_utils.py,sha256=Eax4QUefIzpmVuQxo1y9FqJ6g0qXjg4C0IVZ5uYPscs,4899
|
240
240
|
ai_edge_torch/odml_torch/optimization_barrier.py,sha256=2lmSiu5iXWLFWpupZHvsVeNYNzG5AVGSK3K_CNhS5Sk,2290
|
241
241
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
242
242
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -257,7 +257,7 @@ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=xUuQjoR0NJhuwG36Guyc
|
|
257
257
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
258
258
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
259
259
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
260
|
-
ai_edge_torch/odml_torch/lowerings/context.py,sha256=
|
260
|
+
ai_edge_torch/odml_torch/lowerings/context.py,sha256=dhGS9oFjKLKvc-aBFbNUhb4LqggP_AmPrgpT1EnOw0w,1216
|
261
261
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
|
262
262
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=YagY2IDqtJTpPxjPz2dhb3eyCFTpTSu3ptEPSvEuDtk,10574
|
263
263
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
269
269
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
270
270
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
271
271
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
272
|
-
ai_edge_torch_nightly-0.6.0.
|
273
|
-
ai_edge_torch_nightly-0.6.0.
|
274
|
-
ai_edge_torch_nightly-0.6.0.
|
275
|
-
ai_edge_torch_nightly-0.6.0.
|
276
|
-
ai_edge_torch_nightly-0.6.0.
|
272
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
273
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/METADATA,sha256=ZrMG85KwcLQee8_Nj9jD25avEdJBJQ8zK4o3m7xFifk,2074
|
274
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
275
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
276
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/RECORD,,
|
File without changes
|
File without changes
|