ai-edge-torch-nightly 0.3.0.dev20241122__py3-none-any.whl → 0.3.0.dev20241123__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -74,6 +74,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
74
74
  dim,
75
75
  eps=config.epsilon,
76
76
  zero_centered_gamma=config.zero_centered,
77
+ enable_hlfb=config.enable_hlfb,
77
78
  )
78
79
  elif config.type == cfg.NormalizationType.LAYER_NORM:
79
80
  return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
@@ -23,15 +23,24 @@ import torch.nn.functional as F
23
23
  # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
24
24
  class RMSNorm(torch.nn.Module):
25
25
 
26
- def __init__(self, dim: int, eps: float = 1e-6, zero_centered_gamma=False):
26
+ def __init__(
27
+ self,
28
+ dim: int,
29
+ eps: float = 1e-6,
30
+ zero_centered_gamma=False,
31
+ enable_hlfb: bool = False,
32
+ ):
27
33
  """Initialize the RMSNorm layer.
28
34
 
29
35
  Args:
30
36
  dim (int): dimension of the input tensor.
31
37
  eps (float): A small float value to ensure numerical stability (default:
32
38
  1e-6).
39
+ zero_centered_gamma (bool): Whether or not gamma has an offset.
40
+ enable_hlfb (bool): use HLFB in the op.
33
41
  """
34
42
  super().__init__()
43
+ self.enable_hlfb = enable_hlfb
35
44
  self.eps = eps
36
45
  self.weight = torch.nn.Parameter(torch.ones(dim))
37
46
  self.zero_centered_gamma = zero_centered_gamma
@@ -56,12 +65,20 @@ class RMSNorm(torch.nn.Module):
56
65
  Returns:
57
66
  torch.Tensor: output tensor after applying RMSNorm.
58
67
  """
59
- output = self._norm(x.float()).type_as(x)
60
68
  if self.zero_centered_gamma:
61
- return output * (1 + self.weight)
69
+ w = 1 + self.weight
62
70
  else:
63
- return output * self.weight
71
+ w = self.weight
64
72
 
73
+ if self.enable_hlfb:
74
+ return rms_norm_with_hlfb(
75
+ x,
76
+ w,
77
+ self.eps,
78
+ )
79
+ else:
80
+ output = self._norm(x.float()).type_as(x)
81
+ return output * w
65
82
 
66
83
  class GroupNorm(torch.nn.Module):
67
84
 
@@ -194,6 +211,37 @@ def group_norm_with_hlfb(
194
211
  return y
195
212
 
196
213
 
214
+ def rms_norm_with_hlfb(
215
+ x: torch.Tensor,
216
+ w: torch.Tensor,
217
+ eps: float,
218
+ ):
219
+ """RMS Normalization with high-level function boundary enabled.
220
+
221
+ Args:
222
+ x (torch.Tensor): Input tensor, RMS-normalized over its last dimension.
223
+ w (torch.Tensor): The learned parameter tensor for normalization.
224
+ eps (float): A small float value to ensure numerical stability.
225
+
226
+ Returns:
227
+ The output tensor of RMS Normalization.
228
+ """
229
+ builder = StableHLOCompositeBuilder(
230
+ name="odml.rms_norm", attr={"epsilon": eps}
231
+ )
232
+
233
+ x, w = builder.mark_inputs(x, w)
234
+
235
+ def _norm(x):
236
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
237
+
238
+ output = _norm(x.float()).type_as(x)
239
+ out = output * w
240
+
241
+ out = builder.mark_outputs(out)
242
+ return out
243
+
244
+
197
245
  def layer_norm_with_hlfb(
198
246
  x: torch.Tensor,
199
247
  normalized_shape: list[int],
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241122"
16
+ __version__ = "0.3.0.dev20241123"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241122
3
+ Version: 0.3.0.dev20241123
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=B4r6opjqsPmDJdLbwvWto6dM-0KbsjszxSL6CXmi8K8,706
6
+ ai_edge_torch/version.py,sha256=H-20IqCwUbnpEuOveG8xD4bixK8svU7k-n0hfdJ8AoY,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -115,11 +115,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
115
115
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
116
116
  ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
117
117
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
118
- ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
118
+ ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
119
119
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
120
120
  ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
121
121
  ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
122
- ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
122
+ ai_edge_torch/generative/layers/normalization.py,sha256=_2hps2m2MXEHQWbM-1B4he90hbq8wqOnIDIf-qXHhpc,7589
123
123
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
124
124
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
125
125
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -196,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
196
196
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
197
197
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
198
198
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
199
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
200
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/METADATA,sha256=-YpC-ksRKR8hJ8pZET4Q2F5KbUiRmGOXPhBoEQgIuOA,1897
201
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
202
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
203
- ai_edge_torch_nightly-0.3.0.dev20241122.dist-info/RECORD,,
199
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
200
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/METADATA,sha256=wsp1IYSM424I9ovzuvu-4b_T0VAcCcABGZuChuCdzWM,1897
201
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
202
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
203
+ ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: bdist_wheel (0.45.0)
2
+ Generator: bdist_wheel (0.45.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5