ai-edge-torch-nightly 0.6.0.dev20250821__py3-none-any.whl → 0.6.0.dev20250823__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -54,7 +54,11 @@ class SwiGLU(nn.Module):
54
54
  return F.silu(x) * y
55
55
 
56
56
 
57
- def build_norm(dim: int, config: cfg.NormalizationConfig):
57
+ def build_norm(
58
+ dim: int,
59
+ config: cfg.NormalizationConfig,
60
+ init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
61
+ ):
58
62
  """Builder function for normalizers.
59
63
 
60
64
  Args:
@@ -77,6 +81,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
77
81
  with_scale=config.with_scale,
78
82
  scale_shift=config.scale_shift,
79
83
  enable_hlfb=config.enable_hlfb,
84
+ init_fn=init_fn,
80
85
  )
81
86
  elif config.type == cfg.NormalizationType.LAYER_NORM:
82
87
  return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
@@ -14,6 +14,8 @@
14
14
  # ==============================================================================
15
15
  # Common normalization layers.
16
16
 
17
+ from typing import Callable
18
+
17
19
  from ai_edge_torch.hlfb import StableHLOCompositeBuilder
18
20
  import torch
19
21
  from torch import nn
@@ -31,6 +33,7 @@ class RMSNorm(torch.nn.Module):
31
33
  with_scale: bool = False,
32
34
  scale_shift: float = 1.0,
33
35
  enable_hlfb: bool = False,
36
+ init_fn: Callable[..., torch.Tensor] = lambda *args, **kwargs: None,
34
37
  ):
35
38
  """Initialize the RMSNorm layer.
36
39
 
@@ -42,12 +45,15 @@ class RMSNorm(torch.nn.Module):
42
45
  with_scale (bool): Whether or not to use a scale parameter.
43
46
  scale_shift (float): The shift to apply to the scale parameter.
44
47
  enable_hlfb (bool): use HLFB in the op.
48
+ init_fn: The initialization function to use for the parameters. This is
49
+ used to initialize the scale parameter.
45
50
  """
46
51
  super().__init__()
47
52
  self.dim = dim
48
53
  self.enable_hlfb = enable_hlfb
49
54
  self.eps = eps
50
55
  self.weight = torch.nn.Parameter(torch.ones(dim), requires_grad=False)
56
+ init_fn(self.weight)
51
57
  self.zero_centered_gamma = zero_centered_gamma
52
58
  self.with_scale = with_scale
53
59
  if with_scale:
@@ -124,9 +124,11 @@ class LoweringInterpreter(torch.fx.Interpreter):
124
124
  def run_node(self, node: torch.fx.Node):
125
125
  loc = self._build_loc(node)
126
126
  with loc:
127
- self.lctx = self.lctx.replace(ir_location=loc, node=node)
127
+ self.lctx.ir_location = loc
128
+ self.lctx.node = node
128
129
  res = super().run_node(node)
129
- self.lctx = self.lctx.replace(ir_location=None, node=None)
130
+ self.lctx.ir_location = None
131
+ self.lctx.node = None
130
132
  return res
131
133
 
132
134
  def call_function(self, target, args, kwargs):
@@ -63,7 +63,10 @@ def inline(
63
63
  while True:
64
64
  is_changed = False
65
65
  for op in block.operations:
66
- if op.OPERATION_NAME != func.CallOp.OPERATION_NAME:
66
+ if (
67
+ not hasattr(op, "OPERATION_NAME")
68
+ or op.OPERATION_NAME != func.CallOp.OPERATION_NAME
69
+ ):
67
70
  continue
68
71
 
69
72
  call_op = cast(func.CallOp, op)
@@ -96,7 +99,10 @@ def clone_func_body_ops(func_op: func.FuncOp, ir_inputs: Sequence[ir.Value]):
96
99
 
97
100
  for op in list(func_op.entry_block.operations):
98
101
  cloned_operands = [value_mapping[val] for val in op.operands]
99
- if op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME:
102
+ if (
103
+ hasattr(op, "OPERATION_NAME")
104
+ and op.OPERATION_NAME == func.ReturnOp.OPERATION_NAME
105
+ ):
100
106
  return cloned_operands
101
107
 
102
108
  cloned = cast(ir.Operation, op.operation.clone())
@@ -37,6 +37,3 @@ class LoweringContext:
37
37
  def loc(self):
38
38
  """Shortcut for ir_location."""
39
39
  return self.ir_location
40
-
41
- def replace(self, **kwargs):
42
- return dataclasses.replace(self, **kwargs)
ai_edge_torch/version.py CHANGED
@@ -15,4 +15,4 @@
15
15
 
16
16
  # The next version of ai-edge-torch.
17
17
  # The minor version code should be bumped after every release.
18
- __version__ = "0.6.0.dev20250821"
18
+ __version__ = "0.6.0.dev20250823"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.6.0.dev20250821
3
+ Version: 0.6.0.dev20250823
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
2
2
  ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
5
- ai_edge_torch/version.py,sha256=KCQW3wwoSg8zjssrnlyquD4aO-OeN7u1jWFHk_-trYY,806
5
+ ai_edge_torch/version.py,sha256=JyfBN0yvWEvYs121XsXOTiW8p-4px3FWXcaDmfbzVlY,806
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -173,7 +173,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6i
173
173
  ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
174
174
  ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
175
175
  ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
176
- ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
176
+ ai_edge_torch/generative/layers/builder.py,sha256=UiLOyvwd-bc0n5XcbVxi6JCn_qiKSC6zrDKSZT_TSDA,5030
177
177
  ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
178
178
  ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
179
179
  ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
@@ -181,7 +181,7 @@ ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1
181
181
  ai_edge_torch/generative/layers/kv_cache.py,sha256=A0IFXZ1HD2ZHOWRLfsDO4almgE0KQfjyBOdBFZIGnAs,10893
182
182
  ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
183
183
  ai_edge_torch/generative/layers/model_config.py,sha256=HP-vu1UmAiTmdLlTyZGDUF3le0gji8a61mLCy966NZw,10261
184
- ai_edge_torch/generative/layers/normalization.py,sha256=ijwCpi22NLX-Sygwy5sK9l9WjGvbPIhZvVwoBAonWAo,7014
184
+ ai_edge_torch/generative/layers/normalization.py,sha256=WAhPcLbcC3SCEa6oIgdsojvN306_S8d90WyMQ7ZVP6I,7269
185
185
  ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
186
186
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
187
187
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=2_AgwENsaOgaxgiSqgoj0V0JzQ09dFtP_nBhX-lJK2g,5648
@@ -235,8 +235,8 @@ ai_edge_torch/lowertools/translate_recipe.py,sha256=Oavs0ENKVnIeB-WidXvokTPqNlFf
235
235
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
236
236
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
237
237
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
238
- ai_edge_torch/odml_torch/export.py,sha256=FDseiAOOcgN8HOuwv6HT9sALZgXGj7vZi6kwZPWPF14,14935
239
- ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
238
+ ai_edge_torch/odml_torch/export.py,sha256=WPsATMFsc62aETExBsb1v6Ec1fL3r0CEvxwE7r2C_CA,14931
239
+ ai_edge_torch/odml_torch/export_utils.py,sha256=Eax4QUefIzpmVuQxo1y9FqJ6g0qXjg4C0IVZ5uYPscs,4899
240
240
  ai_edge_torch/odml_torch/optimization_barrier.py,sha256=2lmSiu5iXWLFWpupZHvsVeNYNzG5AVGSK3K_CNhS5Sk,2290
241
241
  ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
242
242
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -257,7 +257,7 @@ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=xUuQjoR0NJhuwG36Guyc
257
257
  ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
258
258
  ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
259
259
  ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
260
- ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
260
+ ai_edge_torch/odml_torch/lowerings/context.py,sha256=dhGS9oFjKLKvc-aBFbNUhb4LqggP_AmPrgpT1EnOw0w,1216
261
261
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
262
262
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=YagY2IDqtJTpPxjPz2dhb3eyCFTpTSu3ptEPSvEuDtk,10574
263
263
  ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
@@ -269,8 +269,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
269
269
  ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
270
270
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
271
271
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
272
- ai_edge_torch_nightly-0.6.0.dev20250821.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
273
- ai_edge_torch_nightly-0.6.0.dev20250821.dist-info/METADATA,sha256=HNSeR4DnS8ax0bUi9CSIOzakw0lJk1ZA8JM889_0rSg,2074
274
- ai_edge_torch_nightly-0.6.0.dev20250821.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
275
- ai_edge_torch_nightly-0.6.0.dev20250821.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
276
- ai_edge_torch_nightly-0.6.0.dev20250821.dist-info/RECORD,,
272
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
273
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/METADATA,sha256=ZrMG85KwcLQee8_Nj9jD25avEdJBJQ8zK4o3m7xFifk,2074
274
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
275
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
276
+ ai_edge_torch_nightly-0.6.0.dev20250823.dist-info/RECORD,,