ai-edge-torch-nightly 0.3.0.dev20241122__py3-none-any.whl → 0.3.0.dev20241123__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/layers/builder.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +52 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241122.dist-info → ai_edge_torch_nightly-0.3.0.dev20241123.dist-info}/top_level.txt +0 -0
@@ -74,6 +74,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
|
|
74
74
|
dim,
|
75
75
|
eps=config.epsilon,
|
76
76
|
zero_centered_gamma=config.zero_centered,
|
77
|
+
enable_hlfb=config.enable_hlfb,
|
77
78
|
)
|
78
79
|
elif config.type == cfg.NormalizationType.LAYER_NORM:
|
79
80
|
return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
|
@@ -23,15 +23,24 @@ import torch.nn.functional as F
|
|
23
23
|
# Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
|
24
24
|
class RMSNorm(torch.nn.Module):
|
25
25
|
|
26
|
-
def __init__(
|
26
|
+
def __init__(
|
27
|
+
self,
|
28
|
+
dim: int,
|
29
|
+
eps: float = 1e-6,
|
30
|
+
zero_centered_gamma=False,
|
31
|
+
enable_hlfb: bool = False,
|
32
|
+
):
|
27
33
|
"""Initialize the RMSNorm layer.
|
28
34
|
|
29
35
|
Args:
|
30
36
|
dim (int): dimension of the input tensor.
|
31
37
|
eps (float): A small float value to ensure numerical stability (default:
|
32
38
|
1e-6).
|
39
|
+
zero_centered_gamma (bool): Whether or not gamma has an offset.
|
40
|
+
enable_hlfb (bool): use HLFB in the op.
|
33
41
|
"""
|
34
42
|
super().__init__()
|
43
|
+
self.enable_hlfb = enable_hlfb
|
35
44
|
self.eps = eps
|
36
45
|
self.weight = torch.nn.Parameter(torch.ones(dim))
|
37
46
|
self.zero_centered_gamma = zero_centered_gamma
|
@@ -56,12 +65,20 @@ class RMSNorm(torch.nn.Module):
|
|
56
65
|
Returns:
|
57
66
|
torch.Tensor: output tensor after applying RMSNorm.
|
58
67
|
"""
|
59
|
-
output = self._norm(x.float()).type_as(x)
|
60
68
|
if self.zero_centered_gamma:
|
61
|
-
|
69
|
+
w = 1 + self.weight
|
62
70
|
else:
|
63
|
-
|
71
|
+
w = self.weight
|
64
72
|
|
73
|
+
if self.enable_hlfb:
|
74
|
+
return rms_norm_with_hlfb(
|
75
|
+
x,
|
76
|
+
w,
|
77
|
+
self.eps,
|
78
|
+
)
|
79
|
+
else:
|
80
|
+
output = self._norm(x.float()).type_as(x)
|
81
|
+
return output * w
|
65
82
|
|
66
83
|
class GroupNorm(torch.nn.Module):
|
67
84
|
|
@@ -194,6 +211,37 @@ def group_norm_with_hlfb(
|
|
194
211
|
return y
|
195
212
|
|
196
213
|
|
214
|
+
def rms_norm_with_hlfb(
|
215
|
+
x: torch.Tensor,
|
216
|
+
w: torch.Tensor,
|
217
|
+
eps: float,
|
218
|
+
):
|
219
|
+
"""RMS Normalization with high-level function boundary enabled.
|
220
|
+
|
221
|
+
Args:
|
222
|
+
x (torch.Tensor): Input tensor for RMS Normalization, with BCHW shape.
|
223
|
+
w (torch.Tensor): The learned parameter tensor for normalization.
|
224
|
+
eps (float): A small float value to ensure numerical stability.
|
225
|
+
|
226
|
+
Returns:
|
227
|
+
The output tensor of RMS Normalization.
|
228
|
+
"""
|
229
|
+
builder = StableHLOCompositeBuilder(
|
230
|
+
name="odml.rms_norm", attr={"epsilon": eps}
|
231
|
+
)
|
232
|
+
|
233
|
+
x, w = builder.mark_inputs(x, w)
|
234
|
+
|
235
|
+
def _norm(x):
|
236
|
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
237
|
+
|
238
|
+
output = _norm(x.float()).type_as(x)
|
239
|
+
out = output * w
|
240
|
+
|
241
|
+
out = builder.mark_outputs(out)
|
242
|
+
return out
|
243
|
+
|
244
|
+
|
197
245
|
def layer_norm_with_hlfb(
|
198
246
|
x: torch.Tensor,
|
199
247
|
normalized_shape: list[int],
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241123
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=H-20IqCwUbnpEuOveG8xD4bixK8svU7k-n0hfdJ8AoY,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -115,11 +115,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
|
|
115
115
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
116
116
|
ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
|
117
117
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
118
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
118
|
+
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
119
119
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
120
120
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
121
121
|
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
122
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
122
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=_2hps2m2MXEHQWbM-1B4he90hbq8wqOnIDIf-qXHhpc,7589
|
123
123
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
|
124
124
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
125
125
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -196,8 +196,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
196
196
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
197
197
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
198
198
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
199
|
-
ai_edge_torch_nightly-0.3.0.
|
200
|
-
ai_edge_torch_nightly-0.3.0.
|
201
|
-
ai_edge_torch_nightly-0.3.0.
|
202
|
-
ai_edge_torch_nightly-0.3.0.
|
203
|
-
ai_edge_torch_nightly-0.3.0.
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/METADATA,sha256=wsp1IYSM424I9ovzuvu-4b_T0VAcCcABGZuChuCdzWM,1897
|
201
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
202
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
203
|
+
ai_edge_torch_nightly-0.3.0.dev20241123.dist-info/RECORD,,
|
File without changes
|