ai-edge-torch-nightly 0.3.0.dev20241122__py3-none-any.whl → 0.3.0.dev20241123__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -74,6 +74,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
74
74
  dim,
75
75
  eps=config.epsilon,
76
76
  zero_centered_gamma=config.zero_centered,
77
+ enable_hlfb=config.enable_hlfb,
77
78
  )
78
79
  elif config.type == cfg.NormalizationType.LAYER_NORM:
79
80
  return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
@@ -23,15 +23,24 @@ import torch.nn.functional as F
23
23
  # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
24
24
  class RMSNorm(torch.nn.Module):
25
25
 
26
- def __init__(self, dim: int, eps: float = 1e-6, zero_centered_gamma=False):
26
+ def __init__(
27
+ self,
28
+ dim: int,
29
+ eps: float = 1e-6,
30
+ zero_centered_gamma=False,
31
+ enable_hlfb: bool = False,
32
+ ):
27
33
  """Initialize the RMSNorm layer.
28
34
 
29
35
  Args:
30
36
  dim (int): dimension of the input tensor.
31
37
  eps (float): A small float value to ensure numerical stability (default:
32
38
  1e-6).
39
+ zero_centered_gamma (bool): Whether or not gamma has an offset.
40
+ enable_hlfb (bool): use HLFB in the op.
33
41
  """
34
42
  super().__init__()
43
+ self.enable_hlfb = enable_hlfb
35
44
  self.eps = eps
36
45
  self.weight = torch.nn.Parameter(torch.ones(dim))
37
46
  self.zero_centered_gamma = zero_centered_gamma
@@ -56,12 +65,20 @@ class RMSNorm(torch.nn.Module):
56
65
  Returns:
57
66
  torch.Tensor: output tensor after applying RMSNorm.
58
67
  """
59
- output = self._norm(x.float()).type_as(x)
60
68
  if self.zero_centered_gamma:
61
- return output * (1 + self.weight)
69
+ w = 1 + self.weight
62
70
  else:
63
- return output * self.weight
71
+ w = self.weight
64
72
 
73
+ if self.enable_hlfb:
74
+ return rms_norm_with_hlfb(
75
+ x,
76
+ w,
77
+ self.eps,
78
+ )
79
+ else:
80
+ output = self._norm(x.float()).type_as(x)
81
+ return output * w
65
82
 
66
83
  class GroupNorm(torch.nn.Module):
67
84
 
@@ -194,6 +211,37 @@ def group_norm_with_hlfb(
194
211
  return y
195
212
 
196
213
 
214
+ def rms_norm_with_hlfb(
215
+ x: torch.Tensor,
216
+ w: torch.Tensor,
217
+ eps: float,
218
+ ):
219
+ """RMS Normalization with high-level function boundary enabled.
220
+
221
+ Args:
222
+ x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape.
223
+ w (torch.Tensor): The learned parameter tensor for normalization.
224
+ eps (float): A small float value to ensure numerical stability.
225
+
226
+ Returns:
227
+ The output tensor of RMS Normalization.
228
+ """
229
+ builder = StableHLOCompositeBuilder(
230
+ name="odml.rms_norm", attr={"epsilon": eps}
231
+ )
232
+
233
+ x, w = builder.mark_inputs(x, w)
234
+
235
+ def _norm(x):
236
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
237
+
238
+ output = _norm(x.float()).type_as(x)
239
+ out = output * w
240
+
241
+ out = builder.mark_outputs(out)
242
+ return out
243
+
244
+
197
245
  def layer_norm_with_hlfb(
198
246
  x: torch.Tensor,
199
247
  normalized_shape: list[int],
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241122"
16
+ __version__ = "0.3.0.dev20241123"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241122
3
+ Version: 0.3.0.dev20241123
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=B4r6opjqsPmDJdLbwvWto6dM-0KbsjszxSL6CXmi8K8,706
6
+ ai_edge_torch/version.py,sha256=H-20IqCwUbnpEuOveG8xD4bixK8svU7k-n0hfdJ8AoY,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -115,11 +115,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
115
115
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
116
116
  ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
117
117
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
118
- ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
118
+ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
119
119
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
120
120
  ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
121
121
  ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
122
- ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
122
+ ai_edge_torch/generative/layers/normalization.py,sha256=_2hps2m2MXEHQWbM-1B4he90hbq8wqOnIDIf-qXHhpc,7589
123
123
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
124
124
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
125
125
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -196,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
196
196
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
197
197
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
198
198
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
199
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
200
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/METADATA,sha256=-YpC-ksRKR8hJ8pZET4Q2F5KbUiRmGOXPhBoEQgIuOA,1897
201
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
202
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
203
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/RECORD,,
199
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
200
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/METADATA,sha256=wsp1IYSM424I9ovzuvu-4b_T0VAcCcABGZuChuCdzWM,1897
201
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
202
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
203
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.45.0)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5