ai-edge-torch-nightly 0.3.0.dev20240921__py3-none-any.whl → 0.3.0.dev20240924__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/openelm/openelm.py +1 -3
- ai_edge_torch/generative/examples/phi/verify.py +1 -0
- ai_edge_torch/generative/layers/builder.py +25 -24
- ai_edge_torch/generative/layers/model_config.py +3 -3
- ai_edge_torch/generative/layers/normalization.py +14 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/METADATA +2 -1
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240921.dist-info → ai_edge_torch_nightly-0.3.0.dev20240924.dist-info}/top_level.txt +0 -0
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
161
161
|
),
|
162
162
|
ff_config=cfg.FeedForwardConfig(
|
163
163
|
type=cfg.FeedForwardType.SEQUENTIAL,
|
164
|
-
activation=cfg.ActivationConfig(
|
165
|
-
cfg.ActivationType.SILU_GLU, gate_is_front=True
|
166
|
-
),
|
164
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
|
167
165
|
intermediate_size=get_intermediate_size(idx),
|
168
166
|
pre_ff_norm_config=norm_config,
|
169
167
|
),
|
@@ -23,34 +23,35 @@ from torch import nn
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
|
25
25
|
|
26
|
-
|
27
|
-
|
28
|
-
) -> Callable[[torch.Tensor], torch.Tensor]:
|
29
|
-
"""Builds an activation function with GLU (Gated Linear Unit).
|
26
|
+
class GeGLU(nn.Module):
|
27
|
+
"""GeGLU is an activation function which is a variant of GELU.
|
30
28
|
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
f(x) = x * act(y),
|
35
|
-
where x is the first half of the input and y is the second half of the input.
|
29
|
+
GeGLU(x) = (xW+b) * GELU(xV+c)
|
30
|
+
See: https://arxiv.org/abs/2002.05202v1
|
31
|
+
"""
|
36
32
|
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
gate_is_front: whether the gate is in front half of the input. Other part is
|
41
|
-
the output in GLU.
|
33
|
+
def __init__(self, d_in: int, d_out: int):
|
34
|
+
super().__init__()
|
35
|
+
self.proj = nn.Linear(d_in, d_out * 2)
|
42
36
|
|
43
|
-
|
44
|
-
|
37
|
+
def forward(self, x: torch.Tensor):
|
38
|
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
39
|
+
return x * F.gelu(gate)
|
40
|
+
|
41
|
+
|
42
|
+
class SwiGLU(nn.Module):
|
43
|
+
"""SwiGLU is an activation function which is a variant of GLU.
|
44
|
+
|
45
|
+
SwiGLU is the same as SiLU_GLU, because the SiLU function is also known as the
|
46
|
+
swish function.
|
47
|
+
|
48
|
+
SwiGLU(x) = Swish(xW+b) * (xV+c)
|
49
|
+
See: https://paperswithcode.com/method/swiglu
|
45
50
|
"""
|
46
51
|
|
47
|
-
def
|
52
|
+
def forward(self, x: torch.Tensor):
|
48
53
|
x, y = x.chunk(2, dim=-1)
|
49
|
-
|
50
|
-
return act(x) * y
|
51
|
-
return x * act(y)
|
52
|
-
|
53
|
-
return _glu
|
54
|
+
return F.silu(x) * y
|
54
55
|
|
55
56
|
|
56
57
|
def build_norm(dim: int, config: cfg.NormalizationConfig):
|
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
|
|
151
152
|
# See: https://github.com/hendrycks/GELUs
|
152
153
|
return lambda x: x * F.sigmoid(1.702 * x)
|
153
154
|
elif config.type == cfg.ActivationType.GE_GLU:
|
154
|
-
return
|
155
|
+
return GeGLU(config.dim_in, config.dim_out)
|
155
156
|
elif config.type == cfg.ActivationType.RELU:
|
156
157
|
return F.relu
|
157
158
|
elif config.type == cfg.ActivationType.SILU_GLU:
|
158
|
-
return
|
159
|
+
return SwiGLU()
|
159
160
|
else:
|
160
161
|
raise ValueError("Unsupported activation type.")
|
@@ -118,9 +118,9 @@ class AttentionConfig:
|
|
118
118
|
@dataclass
|
119
119
|
class ActivationConfig:
|
120
120
|
type: ActivationType = ActivationType.LINEAR
|
121
|
-
#
|
122
|
-
|
123
|
-
|
121
|
+
# Dimension of input and output, used in GeGLU.
|
122
|
+
dim_in: Optional[int] = None
|
123
|
+
dim_out: Optional[int] = None
|
124
124
|
|
125
125
|
|
126
126
|
@dataclass
|
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
|
|
183
183
|
"""
|
184
184
|
x = torch.permute(x, (0, 2, 3, 1))
|
185
185
|
|
186
|
+
# TODO: b/366544750 - Change the "reduction_axes" field to an array, rather than
|
187
|
+
# int32 when the bug is fixed.
|
186
188
|
builder = StableHLOCompositeBuilder(
|
187
|
-
name="odml.group_norm",
|
189
|
+
name="odml.group_norm",
|
190
|
+
attr={
|
191
|
+
"num_groups": num_groups,
|
192
|
+
"eps": eps,
|
193
|
+
"reduction_axes": 3,
|
194
|
+
"channel_axis": 3,
|
195
|
+
},
|
188
196
|
)
|
189
197
|
x, w, b = builder.mark_inputs(x, w, b)
|
190
198
|
x = torch.permute(x, (0, 3, 1, 2))
|
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
|
|
206
214
|
"""Layer Normalization with high-level function boundary enabled.
|
207
215
|
|
208
216
|
Args:
|
209
|
-
x (torch.Tensor): Input tensor for Layer Normalization.
|
217
|
+
x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
|
210
218
|
w (torch.Tensor): The weight tensor for the normalization.
|
211
219
|
b (torch.Tensor): The bias tensor for the normalization.
|
212
220
|
eps (float): A small float value to ensure numerical stability.
|
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
|
|
216
224
|
Returns:
|
217
225
|
The output tensor of Layer Normalization.
|
218
226
|
"""
|
219
|
-
builder = StableHLOCompositeBuilder(
|
227
|
+
builder = StableHLOCompositeBuilder(
|
228
|
+
name="odml.group_norm",
|
229
|
+
attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
|
230
|
+
)
|
220
231
|
x, w, b = builder.mark_inputs(x, w, b)
|
221
232
|
if use_input_shape:
|
222
233
|
normalized_shape = x.shape
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20240921
|
3
|
+
Version: 0.3.0.dev20240924
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
|
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
32
|
Requires-Dist: tf-nightly>=2.18.0.dev20240722
|
33
|
+
Requires-Dist: ai-edge-litert-nightly
|
33
34
|
Requires-Dist: ai-edge-quantizer-nightly
|
34
35
|
|
35
36
|
Library that supports converting PyTorch models into a .tflite format, which can
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=sQUcRP5rShDk3vfblz87j26JciN6PV8S8DJkiiZP5o8,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -48,12 +48,12 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
|
|
48
48
|
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
|
49
49
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
50
|
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
|
51
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
51
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
|
52
52
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
|
53
53
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
54
54
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
|
55
55
|
ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
|
56
|
-
ai_edge_torch/generative/examples/phi/verify.py,sha256=
|
56
|
+
ai_edge_torch/generative/examples/phi/verify.py,sha256=SwPyRjiupD4AsmWW_7FDcMSWaNRmDBu6uVFcBQRoM40,2146
|
57
57
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
58
58
|
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
|
59
59
|
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
|
@@ -89,11 +89,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
|
|
89
89
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
90
90
|
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
91
91
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
92
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
92
|
+
ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
|
93
93
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
94
94
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
95
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
96
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
|
96
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=LDczSHujMgo1WV8IhYVQe-egPkaBEmWFt8wZQ_tgshg,6991
|
97
97
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
98
98
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
99
99
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -166,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
166
166
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
167
167
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
168
168
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
169
|
-
ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
170
|
-
ai_edge_torch_nightly-0.3.0.
|
171
|
-
ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
172
|
-
ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
173
|
-
ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,
|
169
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
170
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/METADATA,sha256=BotYlw1pMxClnHOi8rSb5v6jX0zE7EqUo8b11xvqEII,1897
|
171
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
172
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
173
|
+
ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/RECORD,,
|
File without changes
|
File without changes
|