ai-edge-torch-nightly 0.6.0.dev20250822__py3-none-any.whl → 0.6.0.dev20250823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/layers/builder.py +6 -1
- ai_edge_torch/generative/layers/normalization.py +6 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250822.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250822.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.6.0.dev20250822.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250822.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250822.dist-info → ai_edge_torch_nightly-0.6.0.dev20250823.dist-info}/top_level.txt +0 -0
@@ -54,7 +54,11 @@ class SwiGLU(nn.Module):
|
|
54
54
|
return F.silu(x) * y
|
55
55
|
|
56
56
|
|
57
|
-
def build_norm(
|
57
|
+
def build_norm(
|
58
|
+
dim: int,
|
59
|
+
config: cfg.NormalizationConfig,
|
60
|
+
init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
|
61
|
+
):
|
58
62
|
"""Builder function for normalizers.
|
59
63
|
|
60
64
|
Args:
|
@@ -77,6 +81,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
|
|
77
81
|
with_scale=config.with_scale,
|
78
82
|
scale_shift=config.scale_shift,
|
79
83
|
enable_hlfb=config.enable_hlfb,
|
84
|
+
init_fn=init_fn,
|
80
85
|
)
|
81
86
|
elif config.type == cfg.NormalizationType.LAYER_NORM:
|
82
87
|
return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
|
@@ -14,6 +14,8 @@
|
|
14
14
|
# ==============================================================================
|
15
15
|
# Common normalization layers.
|
16
16
|
|
17
|
+
from typing import Callable
|
18
|
+
|
17
19
|
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
18
20
|
import torch
|
19
21
|
from torch import nn
|
@@ -31,6 +33,7 @@ class RMSNorm(torch.nn.Module):
|
|
31
33
|
with_scale: bool = False,
|
32
34
|
scale_shift: float = 1.0,
|
33
35
|
enable_hlfb: bool = False,
|
36
|
+
init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
|
34
37
|
):
|
35
38
|
"""Initialize the RMSNorm layer.
|
36
39
|
|
@@ -42,12 +45,15 @@ class RMSNorm(torch.nn.Module):
|
|
42
45
|
with_scale (bool): Whether or not to use a scale parameter.
|
43
46
|
scale_shift (float): The shift to apply to the scale parameter.
|
44
47
|
enable_hlfb (bool): use HLFB in the op.
|
48
|
+
init_fn: The initialization function to use for the parameters. This is
|
49
|
+
used to initialize the scale parameter.
|
45
50
|
"""
|
46
51
|
super().__init__()
|
47
52
|
self.dim = dim
|
48
53
|
self.enable_hlfb = enable_hlfb
|
49
54
|
self.eps = eps
|
50
55
|
self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
|
56
|
+
init_fn(self.weight)
|
51
57
|
self.zero_centered_gamma = zero_centered_gamma
|
52
58
|
self.with_scale = with_scale
|
53
59
|
if with_scale:
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.6.0.
|
3
|
+
Version: 0.6.0.dev20250823
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=JyfBN0yvWEvYs121XsXOTiW8p-4px3FWXcaDmfbzVlY,806
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -173,7 +173,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6i
|
|
173
173
|
ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
|
174
174
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
|
175
175
|
ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
|
176
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
176
|
+
ai_edge_torch/generative/layers/builder.py,sha256=UiLOyvwd-bc0n5XcbVxi6JCn_qiKSC6zrDKSZT_TSDA,5030
|
177
177
|
ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
|
178
178
|
ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
|
179
179
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
|
@@ -181,7 +181,7 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1
|
|
181
181
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=A0IFXZ1HD2ZHOWRLfsDO4almgE0KQfjyBOdBFZIGnAs,10893
|
182
182
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
183
183
|
ai_edge_torch/generative/layers/model_config.py,sha256=HP-vu1UmAiTmdLlTyZGDUF3le0gji8a61mLCy966NZw,10261
|
184
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
184
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=WAhPcLbcC3SCEa6oIgdsojvN306_S8d90WyMQ7ZVP6I,7269
|
185
185
|
ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
|
186
186
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
187
187
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
|
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
269
269
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
270
270
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
271
271
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
272
|
-
ai_edge_torch_nightly-0.6.0.
|
273
|
-
ai_edge_torch_nightly-0.6.0.
|
274
|
-
ai_edge_torch_nightly-0.6.0.
|
275
|
-
ai_edge_torch_nightly-0.6.0.
|
276
|
-
ai_edge_torch_nightly-0.6.0.
|
272
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
273
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/METADATA,sha256=ZrMG85KwcLQee8_Nj9jD25avEdJBJQ8zK4o3m7xFifk,2074
|
274
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
275
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
276
|
+
ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/RECORD,,
|
File without changes
|
File without changes
|