ai-edge-torch-nightly 0.3.0.dev20240921__py3-none-any.whl → 0.3.0.dev20240924__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -3
- ai_edge_torch/generative/examples/phi/verify.py +1 -0
- ai_edge_torch/generative/layers/builder.py +25 -24
- ai_edge_torch/generative/layers/model_config.py +3 -3
- ai_edge_torch/generative/layers/normalization.py +14 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/METADATA +2 -1
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/top_level.txt +0 -0
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
161
161
|
),
|
162
162
|
ff_config=cfg.FeedForwardConfig(
|
163
163
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
164
|
-
activation=cfg.ActivationConfig(
|
165
|
-
cfg.ActivationType.SILU_GLU, gate_is_front=True
|
166
|
-
),
|
164
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
|
167
165
|
intermediate_size=get_intermediate_size(idx),
|
168
166
|
pre_ff_norm_config=norm_config,
|
169
167
|
),
|
@@ -23,34 +23,35 @@ from torch import nn
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
|
25
25
|
|
26
|
-
|
27
|
-
|
28
|
-
) -> Callable[[torch.Tensor], torch.Tensor]:
|
29
|
-
"""Builds an activation function with GLU (Gated Linear Unit).
|
26
|
+
class GeGLU(nn.Module):
|
27
|
+
"""GeGLU is an activation function which is a variant of GELU.
|
30
28
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
f(x) = x * act(y),
|
35
|
-
where x is the first half of the input and y is the second half of the input.
|
29
|
+
GeGLU(x) = (xW+b) * GELU(xV+c)
|
30
|
+
See: https://arxiv.org/abs/2002.05202v1
|
31
|
+
"""
|
36
32
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
gate_is_front: whether the gate is in front half of the input. Other part is
|
41
|
-
the output in GLU.
|
33
|
+
def __init__(self, d_in: int, d_out: int):
|
34
|
+
super().__init__()
|
35
|
+
self.proj = nn.Linear(d_in, d_out * 2)
|
42
36
|
|
43
|
-
|
44
|
-
|
37
|
+
def forward(self, x: torch.Tensor):
|
38
|
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
39
|
+
return x * F.gelu(gate)
|
40
|
+
|
41
|
+
|
42
|
+
class SwiGLU(nn.Module):
|
43
|
+
"""SwiGLU is an activation function which is a variant of GLU.
|
44
|
+
|
45
|
+
SwiGLU is same as SiLU_GLU, because The SiLU function is also known as the
|
46
|
+
swish function.
|
47
|
+
|
48
|
+
SwiGLU(x) = Swish(xW+b) * (xV+c)
|
49
|
+
See: https://paperswithcode.com/method/swiglu
|
45
50
|
"""
|
46
51
|
|
47
|
-
def
|
52
|
+
def forward(self, x: torch.Tensor):
|
48
53
|
x, y = x.chunk(2, dim=-1)
|
49
|
-
|
50
|
-
return act(x) * y
|
51
|
-
return x * act(y)
|
52
|
-
|
53
|
-
return _glu
|
54
|
+
return F.silu(x) * y
|
54
55
|
|
55
56
|
|
56
57
|
def build_norm(dim: int, config: cfg.NormalizationConfig):
|
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
|
|
151
152
|
# See: https://github.com/hendrycks/GELUs
|
152
153
|
return lambda x: x * F.sigmoid(1.702 * x)
|
153
154
|
elif config.type == cfg.ActivationType.GE_GLU:
|
154
|
-
return
|
155
|
+
return GeGLU(config.dim_in, config.dim_out)
|
155
156
|
elif config.type == cfg.ActivationType.RELU:
|
156
157
|
return F.relu
|
157
158
|
elif config.type == cfg.ActivationType.SILU_GLU:
|
158
|
-
return
|
159
|
+
return SwiGLU()
|
159
160
|
else:
|
160
161
|
raise ValueError("Unsupported activation type.")
|
@@ -118,9 +118,9 @@ class AttentionConfig:
|
|
118
118
|
@dataclass
|
119
119
|
class ActivationConfig:
|
120
120
|
type: ActivationType = ActivationType.LINEAR
|
121
|
-
#
|
122
|
-
|
123
|
-
|
121
|
+
# Dimension of input and output, used in GeGLU.
|
122
|
+
dim_in: Optional[int] = None
|
123
|
+
dim_out: Optional[int] = None
|
124
124
|
|
125
125
|
|
126
126
|
@dataclass
|
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
|
|
183
183
|
"""
|
184
184
|
x = torch.permute(x, (0, 2, 3, 1))
|
185
185
|
|
186
|
+
# TODO: b/366544750 - Change "reduction_axes" field as an array, rather than
|
187
|
+
# int32 when the bug is fixed.
|
186
188
|
builder = StableHLOCompositeBuilder(
|
187
|
-
name="odml.group_norm",
|
189
|
+
name="odml.group_norm",
|
190
|
+
attr={
|
191
|
+
"num_groups": num_groups,
|
192
|
+
"eps": eps,
|
193
|
+
"reduction_axes": 3,
|
194
|
+
"channel_axis": 3,
|
195
|
+
},
|
188
196
|
)
|
189
197
|
x, w, b = builder.mark_inputs(x, w, b)
|
190
198
|
x = torch.permute(x, (0, 3, 1, 2))
|
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
|
|
206
214
|
"""Layer Normalization with high-level function boundary enabled.
|
207
215
|
|
208
216
|
Args:
|
209
|
-
x (torch.Tensor): Input tensor for Layer Normalization.
|
217
|
+
x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
|
210
218
|
w (torch.Tensor): The weight tensor for the normalization.
|
211
219
|
b (torch.Tensor): The bias tensor for the normalization.
|
212
220
|
eps (float): A small float value to ensure numerical stability.
|
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
|
|
216
224
|
Returns:
|
217
225
|
The output tensor of Layer Normalization.
|
218
226
|
"""
|
219
|
-
builder = StableHLOCompositeBuilder(
|
227
|
+
builder = StableHLOCompositeBuilder(
|
228
|
+
name="odml.group_norm",
|
229
|
+
attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
|
230
|
+
)
|
220
231
|
x, w, b = builder.mark_inputs(x, w, b)
|
221
232
|
if use_input_shape:
|
222
233
|
normalized_shape = x.shape
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240924
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
|
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
32
|
Requires-Dist: tf-nightly>=2.18.0.dev20240722
|
33
|
+
Requires-Dist: ai-edge-litert-nightly
|
33
34
|
Requires-Dist: ai-edge-quantizer-nightly
|
34
35
|
|
35
36
|
Library that supports converting PyTorch models into a .tflite format, which can
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=sQUcRP5rShDk3vfblz87j26JciN6PV8S8DJkiiZP5o8,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -48,12 +48,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
|
|
48
48
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
|
49
49
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
50
|
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
|
51
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
51
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
|
52
52
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
|
53
53
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
54
54
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
|
55
55
|
ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
|
56
|
-
ai_edge_torch/generative/examples/phi/verify.py,sha256=
|
56
|
+
ai_edge_torch/generative/examples/phi/verify.py,sha256=SwPyRjiupD4AsmWW_7FDcMSWaNRmDBu6uVFcBQRoM40,2146
|
57
57
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
58
58
|
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
|
59
59
|
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
|
@@ -89,11 +89,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
|
|
89
89
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
90
90
|
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
91
91
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
92
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
92
|
+
ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
|
93
93
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
94
94
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
95
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
96
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
|
96
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=LDczSHujMgo1WV8IhYVQe-egPkaBEmWFt8wZQ_tgshg,6991
|
97
97
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
98
98
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
99
99
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -166,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
166
166
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
167
167
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
168
168
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
169
|
-
ai_edge_torch_nightly-0.3.0.
|
170
|
-
ai_edge_torch_nightly-0.3.0.
|
171
|
-
ai_edge_torch_nightly-0.3.0.
|
172
|
-
ai_edge_torch_nightly-0.3.0.
|
173
|
-
ai_edge_torch_nightly-0.3.0.
|
169
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
170
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/METADATA,sha256=BotYlw1pMxClnHOi8rSb5v6jX0zE7EqUo8b11xvqEII,1897
|
171
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
172
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
173
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/RECORD,,
|
File without changes
|
File without changes
|