ai-edge-torch-nightly 0.3.0.dev20241122__py3-none-any.whl → 0.3.0.dev20241123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/layers/builder.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +52 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/top_level.txt +0 -0
@@ -74,6 +74,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
|
|
74
74
|
dim,
|
75
75
|
eps=config.epsilon,
|
76
76
|
zero_centered_gamma=config.zero_centered,
|
77
|
+
enable_hlfb=config.enable_hlfb,
|
77
78
|
)
|
78
79
|
elif config.type == cfg.NormalizationType.LAYER_NORM:
|
79
80
|
return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
|
@@ -23,15 +23,24 @@ import torch.nn.functional as F
|
|
23
23
|
# Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
|
24
24
|
class RMSNorm(torch.nn.Module):
|
25
25
|
|
26
|
-
def __init__(
|
26
|
+
def __init__(
|
27
|
+
self,
|
28
|
+
dim: int,
|
29
|
+
eps: float = 1e-6,
|
30
|
+
zero_centered_gamma=False,
|
31
|
+
enable_hlfb: bool = False,
|
32
|
+
):
|
27
33
|
"""Initialize the RMSNorm layer.
|
28
34
|
|
29
35
|
Args:
|
30
36
|
dim (int): dimension of the input tensor.
|
31
37
|
eps (float): A small float value to ensure numerical stability (default:
|
32
38
|
1e-6).
|
39
|
+
zero_centered_gamma (bool): Whether or not gamma has an offset.
|
40
|
+
enable_hlfb (bool): use HLFB in the op.
|
33
41
|
"""
|
34
42
|
super().__init__()
|
43
|
+
self.enable_hlfb = enable_hlfb
|
35
44
|
self.eps = eps
|
36
45
|
self.weight = torch.nn.Parameter(torch.ones(dim))
|
37
46
|
self.zero_centered_gamma = zero_centered_gamma
|
@@ -56,12 +65,20 @@ class RMSNorm(torch.nn.Module):
|
|
56
65
|
Returns:
|
57
66
|
torch.Tensor: output tensor after applying RMSNorm.
|
58
67
|
"""
|
59
|
-
output = self._norm(x.float()).type_as(x)
|
60
68
|
if self.zero_centered_gamma:
|
61
|
-
|
69
|
+
w = 1 + self.weight
|
62
70
|
else:
|
63
|
-
|
71
|
+
w = self.weight
|
64
72
|
|
73
|
+
if self.enable_hlfb:
|
74
|
+
return rms_norm_with_hlfb(
|
75
|
+
x,
|
76
|
+
w,
|
77
|
+
self.eps,
|
78
|
+
)
|
79
|
+
else:
|
80
|
+
output = self._norm(x.float()).type_as(x)
|
81
|
+
return output * w
|
65
82
|
|
66
83
|
class GroupNorm(torch.nn.Module):
|
67
84
|
|
@@ -194,6 +211,37 @@ def group_norm_with_hlfb(
|
|
194
211
|
return y
|
195
212
|
|
196
213
|
|
214
|
+
def rms_norm_with_hlfb(
    x: torch.Tensor,
    w: torch.Tensor,
    eps: float,
) -> torch.Tensor:
  """RMS Normalization with high-level function boundary enabled.

  Wraps the RMS normalization computation in an ``odml.rms_norm`` StableHLO
  composite (via ``StableHLOCompositeBuilder``). The decomposition inside the
  boundary normalizes ``x`` over its last dimension in float32, casts the
  result back to ``x``'s dtype, and then scales it by ``w``.

  Args:
    x (torch.Tensor): Input tensor for RMS Normalization. Normalization is
      applied over the last dimension (not a BCHW channel axis).
    w (torch.Tensor): The learned scale tensor. It is multiplied elementwise
      with the normalized input, so it must broadcast against ``x``.
    eps (float): A small float value to ensure numerical stability. It is
      also recorded as the composite's ``epsilon`` attribute.

  Returns:
    The output tensor of RMS Normalization.
  """
  builder = StableHLOCompositeBuilder(
      name="odml.rms_norm", attr={"epsilon": eps}
  )

  # Both the activation and the (possibly already offset) scale tensor are
  # marked as composite inputs.
  x, w = builder.mark_inputs(x, w)

  # Compute the RMS statistic in float32 for numerical stability, then cast
  # back to x's dtype before scaling. This matches the non-HLFB path in
  # RMSNorm.forward.
  x_f32 = x.float()
  normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + eps)
  out = normed.type_as(x) * w

  return builder.mark_outputs(out)
|
243
|
+
|
244
|
+
|
197
245
|
def layer_norm_with_hlfb(
|
198
246
|
x: torch.Tensor,
|
199
247
|
normalized_shape: list[int],
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20241122
|
3
|
+
Version: 0.3.0.dev20241123
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=H-20IqCwUbnpEuOveG8xD4bixK8svU7k-n0hfdJ8AoY,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -115,11 +115,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
|
|
115
115
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
116
116
|
ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
|
117
117
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
118
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
118
|
+
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
119
119
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
120
120
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
121
121
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
122
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
122
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=_2hps2m2MXEHQWbM-1B4he90hbq8wqOnIDIf-qXHhpc,7589
|
123
123
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
124
124
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
125
125
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -196,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
196
196
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
197
197
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
198
198
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
199
|
-
ai_edge_torch_nightly-0.3.0.
|
200
|
-
ai_edge_torch_nightly-0.3.0.
|
201
|
-
ai_edge_torch_nightly-0.3.0.
|
202
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
-
ai_edge_torch_nightly-0.3.0.
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/METADATA,sha256=wsp1IYSM424I9ovzuvu-4b_T0VAcCcABGZuChuCdzWM,1897
|
201
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
202
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/RECORD,,
|
File without changes
|