ai-edge-torch-nightly 0.6.0.dev20250822__py3-none-any.whl → 0.6.0.dev20250823__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -54,7 +54,11 @@ class SwiGLU(nn.Module):
54
54
  return F.silu(x) * y
55
55
 
56
56
 
57
- def build_norm(dim: int, config: cfg.NormalizationConfig):
57
+ def build_norm(
58
+ dim: int,
59
+ config: cfg.NormalizationConfig,
60
+ init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
61
+ ):
58
62
  """Builder function for normalizers.
59
63
 
60
64
  Args:
@@ -77,6 +81,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
77
81
  with_scale=config.with_scale,
78
82
  scale_shift=config.scale_shift,
79
83
  enable_hlfb=config.enable_hlfb,
84
+ init_fn=init_fn,
80
85
  )
81
86
  elif config.type == cfg.NormalizationType.LAYER_NORM:
82
87
  return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
@@ -14,6 +14,8 @@
14
14
  # ==============================================================================
15
15
  # Common normalization layers.
16
16
 
17
+ from typing import Callable
18
+
17
19
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
18
20
  import torch
19
21
  from torch import nn
@@ -31,6 +33,7 @@ class RMSNorm(torch.nn.Module):
31
33
  with_scale: bool = False,
32
34
  scale_shift: float = 1.0,
33
35
  enable_hlfb: bool = False,
36
+ init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
34
37
  ):
35
38
  """Initialize the RMSNorm layer.
36
39
 
@@ -42,12 +45,15 @@ class RMSNorm(torch.nn.Module):
42
45
  with_scale (bool): Whether or not to use a scale parameter.
43
46
  scale_shift (float): The shift to apply to the scale parameter.
44
47
  enable_hlfb (bool): use HLFB in the op.
48
+ init_fn: The initialization function to use for the parameters. This is
49
+ used to initialize the scale parameter.
45
50
  """
46
51
  super().__init__()
47
52
  self.dim = dim
48
53
  self.enable_hlfb = enable_hlfb
49
54
  self.eps = eps
50
55
  self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
56
+ init_fn(self.weight)
51
57
  self.zero_centered_gamma = zero_centered_gamma
52
58
  self.with_scale = with_scale
53
59
  if with_scale:
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.6.0.dev20250822"
18
+ __version__ = "0.6.0.dev20250823"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.6.0.dev20250822
3
+ Version: 0.6.0.dev20250823
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=3opXPBORBcfxsu8h1LE4OMIWY6oXA1l6xSP65lCMRog,806
5
+ ai_edge_torch/version.py,sha256=JyfBN0yvWEvYs121XsXOTiW8p-4px3FWXcaDmfbzVlY,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -173,7 +173,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6i
173
173
  ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
174
174
  ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
175
175
  ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
176
- ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
176
+ ai_edge_torch/generative/layers/builder.py,sha256=UiLOyvwd-bc0n5XcbVxi6JCn_qiKSC6zrDKSZT_TSDA,5030
177
177
  ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
178
178
  ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
179
179
  ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
@@ -181,7 +181,7 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1
181
181
  ai_edge_torch/generative/layers/kv_cache.py,sha256=A0IFXZ1HD2ZHOWRLfsDO4almgE0KQfjyBOdBFZIGnAs,10893
182
182
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
183
183
  ai_edge_torch/generative/layers/model_config.py,sha256=HP-vu1UmAiTmdLlTyZGDUF3le0gji8a61mLCy966NZw,10261
184
- ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
184
+ ai_edge_torch/generative/layers/normalization.py,sha256=WAhPcLbcC3SCEa6oIgdsojvN306_S8d90WyMQ7ZVP6I,7269
185
185
  ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
186
186
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
187
187
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
269
269
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
270
270
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
271
271
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
272
- ai_edge_torch_nightly-0.6.0.dev20250822.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
273
- ai_edge_torch_nightly-0.6.0.dev20250822.dist-info/METADATA,sha256=m65Q6FH51OLricsQOL0tRYM50M0Mp3NKniIlZqfdhhc,2074
274
- ai_edge_torch_nightly-0.6.0.dev20250822.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
275
- ai_edge_torch_nightly-0.6.0.dev20250822.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
276
- ai_edge_torch_nightly-0.6.0.dev20250822.dist-info/RECORD,,
272
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
273
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/METADATA,sha256=ZrMG85KwcLQee8_Nj9jD25avEdJBJQ8zK4o3m7xFifk,2074
274
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
275
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
276
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/RECORD,,