ai-edge-torch-nightly 0.3.0.dev20241006__py3-none-any.whl → 0.3.0.dev20241007__py3-none-any.whl
- ai_edge_torch/generative/examples/phi/phi2.py +0 -1
- ai_edge_torch/generative/layers/builder.py +1 -1
- ai_edge_torch/generative/layers/model_config.py +0 -3
- ai_edge_torch/generative/layers/normalization.py +13 -29
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi/phi2.py
CHANGED
@@ -60,7 +60,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.LAYER_NORM,
-      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
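With the use_input_shape flag removed, the Phi-2 example falls back to the default behavior of normalizing over the weight (embedding) dimension. A minimal sketch of the resulting configuration, assuming only the module path listed in the file summary above:

    from ai_edge_torch.generative.layers import model_config as cfg

    # Phi-2 now relies on the default LayerNorm behavior (normalize over the
    # last dimension), so no extra flag is needed.
    norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)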
ai_edge_torch/generative/layers/builder.py
CHANGED
@@ -77,7 +77,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
     return normalization.LayerNorm(
-        dim, config.epsilon, config.enable_hlfb
+        dim, config.epsilon, config.enable_hlfb
     )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
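build_norm now passes LayerNorm only the dimension, epsilon, and HLFB flag. A hedged usage sketch; the 2048 width is an arbitrary example value and the config's default epsilon is assumed:

    from ai_edge_torch.generative.layers import builder
    from ai_edge_torch.generative.layers import model_config as cfg

    # Build a LayerNorm for a hypothetical 2048-wide embedding.
    norm = builder.build_norm(
        2048, cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
    )
    print(type(norm).__name__, norm.normalized_shape)  # LayerNorm (2048,)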
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -69,9 +69,6 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
-  # Whether to use the input shape to determine the dimension of normalization
-  # when type is LAYER_NORM.
-  use_input_shape: bool = True


 @dataclass
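NormalizationConfig keeps its remaining fields (type, epsilon, enable_hlfb, zero_centered, group_num); only the LAYER_NORM-specific switch is gone. A hedged sketch of what that means for callers, assuming the dataclass-generated constructor:

    from ai_edge_torch.generative.layers import model_config as cfg

    # Passing the removed field now fails at construction time.
    try:
      cfg.NormalizationConfig(
          type=cfg.NormalizationType.LAYER_NORM, use_input_shape=False
      )
    except TypeError as err:
      print("use_input_shape is gone:", err)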
ai_edge_torch/generative/layers/normalization.py
CHANGED
@@ -86,8 +86,8 @@ class GroupNorm(torch.nn.Module):
     self.enable_hlfb = enable_hlfb
     self.group_num = group_num
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.
-    self.bias = torch.nn.Parameter(torch.
+    self.weight = torch.nn.Parameter(torch.empty(dim))
+    self.bias = torch.nn.Parameter(torch.empty(dim))

   def forward(self, x):
     """Running the forward pass of GroupNorm layer.
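GroupNorm (and LayerNorm below) now allocates weight and bias with torch.empty, deferring real values to checkpoint loading. A minimal plain-PyTorch sketch of that allocate-then-load pattern; DummyNorm and the literal state dict are illustrative, not part of the library:

    import torch

    class DummyNorm(torch.nn.Module):
      """Toy module mirroring the pattern: allocate storage, then load values."""

      def __init__(self, dim: int):
        super().__init__()
        # Uninitialized storage; contents are meaningless until overwritten.
        self.weight = torch.nn.Parameter(torch.empty(dim))
        self.bias = torch.nn.Parameter(torch.empty(dim))

    m = DummyNorm(8)
    # Checkpoint loading replaces the uninitialized tensors.
    m.load_state_dict({"weight": torch.ones(8), "bias": torch.zeros(8)})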
@@ -117,25 +117,22 @@ class LayerNorm(torch.nn.Module):
       dim: int,
       eps: float = 1e-5,
       enable_hlfb: bool = False,
-      use_input_shape: bool = True,
   ):
     """Initialize the LayerNorm layer.

     Args:
       dim (int): dimension of the input tensor.
       eps (float): A small float value to ensure numerical stability (default:
-        1e-
+        1e-5).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
-      use_input_shape (bool): Whether to use the input shape to determine the
-        dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
-    self.
+    self.normalized_shape = (dim,)
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.
-    self.bias = torch.nn.Parameter(torch.
+    self.weight = torch.nn.Parameter(torch.empty(dim))
+    self.bias = torch.nn.Parameter(torch.empty(dim))

   def forward(self, x):
     """Running the forward pass of LayerNorm layer.
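LayerNorm now fixes normalized_shape = (dim,) at construction and reuses it in both the HLFB and plain paths (see the forward hunk below). A small, hedged check with stock PyTorch showing what normalizing over (dim,) means; shapes and values are arbitrary:

    import torch
    import torch.nn.functional as F

    dim = 8
    x = torch.randn(2, 5, dim)
    w, b = torch.ones(dim), torch.zeros(dim)

    # Normalizing over (dim,) standardizes each length-`dim` vector along the
    # last axis, matching torch.nn.LayerNorm(dim) with identity affine params.
    y = F.layer_norm(x, (dim,), w, b, eps=1e-5)
    ref = torch.nn.LayerNorm(dim, eps=1e-5)(x)
    assert torch.allclose(y, ref, atol=1e-6)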
@@ -148,18 +145,11 @@ class LayerNorm(torch.nn.Module):
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x, self.
+          x, self.normalized_shape, self.weight, self.bias, self.eps
       )
-
-
-
-      weight = self.weight.broadcast_to(x.shape)
-      bias = self.bias.broadcast_to(x.shape)
-    else:
-      normalized_shape = self.weight.shape
-      weight = self.weight
-      bias = self.bias
-    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
+    return F.layer_norm(
+        x, self.normalized_shape, self.weight, self.bias, self.eps
+    )


 def group_norm_with_hlfb(
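The deleted branch normalized over the full input shape (broadcasting weight and bias to x.shape), while the surviving path always normalizes over the trailing weight dimension. A hedged sketch in plain PyTorch of why the two are not interchangeable in general:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 8)
    w, b = torch.ones(8), torch.zeros(8)

    # Kept behavior: statistics computed per trailing length-8 vector.
    last_dim = F.layer_norm(x, (8,), w, b, eps=1e-5)

    # Removed behavior (sketch): statistics over the whole tensor, with the
    # affine parameters broadcast to the input shape.
    full_shape = F.layer_norm(
        x, tuple(x.shape), w.broadcast_to(x.shape), b.broadcast_to(x.shape), eps=1e-5
    )

    print(torch.allclose(last_dim, full_shape))  # typically False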
@@ -206,20 +196,20 @@ def group_norm_with_hlfb(

 def layer_norm_with_hlfb(
     x: torch.Tensor,
+    normalized_shape: list[int],
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
-    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.

   Args:
     x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
+    normalized_shape (list[int]): Input shape from an expected input of size,
+      same as https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
-    use_input_shape (bool): Whether to use the input shape to determine the
-      dimension of normalization.

   Returns:
     The output tensor of Layer Normalization.
@@ -229,12 +219,6 @@ def layer_norm_with_hlfb(
       attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
   )
   x, w, b = builder.mark_inputs(x, w, b)
-  if use_input_shape:
-    normalized_shape = x.shape
-    w = w.broadcast_to(x.shape)
-    b = b.broadcast_to(x.shape)
-  else:
-    normalized_shape = w.shape
   y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241006
+Version: 0.3.0.dev20241007
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=G4658C8Iyg3OGjdMypYdKW18bjZ3ysuHhek8VfS6mCc,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
 ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
@@ -101,11 +101,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
+ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/METADATA,sha256=GgrT0zBBoJAVXVdeYnu9-ZaI2uFn9zpd1BS4OEVPNSY,1897
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/RECORD,,