ai-edge-torch-nightly 0.3.0.dev20241006__py3-none-any.whl → 0.3.0.dev20241007__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/phi/phi2.py +0 -1
- ai_edge_torch/generative/layers/builder.py +1 -1
- ai_edge_torch/generative/layers/model_config.py +0 -3
- ai_edge_torch/generative/layers/normalization.py +13 -29
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/RECORD +10 -10
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241006.dist-info → ai_edge_torch_nightly-0.3.0.dev20241007.dist-info}/top_level.txt +0 -0
@@ -60,7 +60,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
60
60
|
)
|
61
61
|
norm_config = cfg.NormalizationConfig(
|
62
62
|
type=cfg.NormalizationType.LAYER_NORM,
|
63
|
-
use_input_shape=False, # Phi-2 does layer-norm with the weight shape.
|
64
63
|
)
|
65
64
|
block_config = cfg.TransformerBlockConfig(
|
66
65
|
attn_config=attn_config,
|
@@ -77,7 +77,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
|
|
77
77
|
)
|
78
78
|
elif config.type == cfg.NormalizationType.LAYER_NORM:
|
79
79
|
return normalization.LayerNorm(
|
80
|
-
dim, config.epsilon, config.enable_hlfb
|
80
|
+
dim, config.epsilon, config.enable_hlfb
|
81
81
|
)
|
82
82
|
elif config.type == cfg.NormalizationType.GROUP_NORM:
|
83
83
|
return normalization.GroupNorm(
|
@@ -69,9 +69,6 @@ class NormalizationConfig:
|
|
69
69
|
zero_centered: bool = False
|
70
70
|
# Number of groups used in group normalization.
|
71
71
|
group_num: Optional[float] = None
|
72
|
-
# Whether to use the input shape to determine the dimension of normalization
|
73
|
-
# when type is LAYER_NORM.
|
74
|
-
use_input_shape: bool = True
|
75
72
|
|
76
73
|
|
77
74
|
@dataclass
|
@@ -86,8 +86,8 @@ class GroupNorm(torch.nn.Module):
|
|
86
86
|
self.enable_hlfb = enable_hlfb
|
87
87
|
self.group_num = group_num
|
88
88
|
self.eps = eps
|
89
|
-
self.weight = torch.nn.Parameter(torch.
|
90
|
-
self.bias = torch.nn.Parameter(torch.
|
89
|
+
self.weight = torch.nn.Parameter(torch.empty(dim))
|
90
|
+
self.bias = torch.nn.Parameter(torch.empty(dim))
|
91
91
|
|
92
92
|
def forward(self, x):
|
93
93
|
"""Running the forward pass of GroupNorm layer.
|
@@ -117,25 +117,22 @@ class LayerNorm(torch.nn.Module):
|
|
117
117
|
dim: int,
|
118
118
|
eps: float = 1e-5,
|
119
119
|
enable_hlfb: bool = False,
|
120
|
-
use_input_shape: bool = True,
|
121
120
|
):
|
122
121
|
"""Initialize the LayerNorm layer.
|
123
122
|
|
124
123
|
Args:
|
125
124
|
dim (int): dimension of the input tensor.
|
126
125
|
eps (float): A small float value to ensure numerical stability (default:
|
127
|
-
1e-
|
126
|
+
1e-5).
|
128
127
|
enable_hlfb (bool): Whether to convert this normalization into a single
|
129
128
|
op.
|
130
|
-
use_input_shape (bool): Whether to use the input shape to determine the
|
131
|
-
dimension of normalization (default: True).
|
132
129
|
"""
|
133
130
|
super().__init__()
|
134
131
|
self.enable_hlfb = enable_hlfb
|
135
|
-
self.
|
132
|
+
self.normalized_shape = (dim,)
|
136
133
|
self.eps = eps
|
137
|
-
self.weight = torch.nn.Parameter(torch.
|
138
|
-
self.bias = torch.nn.Parameter(torch.
|
134
|
+
self.weight = torch.nn.Parameter(torch.empty(dim))
|
135
|
+
self.bias = torch.nn.Parameter(torch.empty(dim))
|
139
136
|
|
140
137
|
def forward(self, x):
|
141
138
|
"""Running the forward pass of LayerNorm layer.
|
@@ -148,18 +145,11 @@ class LayerNorm(torch.nn.Module):
|
|
148
145
|
"""
|
149
146
|
if self.enable_hlfb:
|
150
147
|
return layer_norm_with_hlfb(
|
151
|
-
x, self.
|
148
|
+
x, self.normalized_shape, self.weight, self.bias, self.eps
|
152
149
|
)
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
weight = self.weight.broadcast_to(x.shape)
|
157
|
-
bias = self.bias.broadcast_to(x.shape)
|
158
|
-
else:
|
159
|
-
normalized_shape = self.weight.shape
|
160
|
-
weight = self.weight
|
161
|
-
bias = self.bias
|
162
|
-
return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
|
150
|
+
return F.layer_norm(
|
151
|
+
x, self.normalized_shape, self.weight, self.bias, self.eps
|
152
|
+
)
|
163
153
|
|
164
154
|
|
165
155
|
def group_norm_with_hlfb(
|
@@ -206,20 +196,20 @@ def group_norm_with_hlfb(
|
|
206
196
|
|
207
197
|
def layer_norm_with_hlfb(
|
208
198
|
x: torch.Tensor,
|
199
|
+
normalized_shape: list[int],
|
209
200
|
w: torch.Tensor,
|
210
201
|
b: torch.Tensor,
|
211
202
|
eps: float,
|
212
|
-
use_input_shape: bool,
|
213
203
|
):
|
214
204
|
"""Layer Normalization with high-level function boundary enabled.
|
215
205
|
|
216
206
|
Args:
|
217
207
|
x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
|
208
|
+
normalized_shape (list[int]): Input shape from an expected input of size,
|
209
|
+
same as https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
|
218
210
|
w (torch.Tensor): The weight tensor for the normalization.
|
219
211
|
b (torch.Tensor): The bias tensor for the normalization.
|
220
212
|
eps (float): A small float value to ensure numerical stability.
|
221
|
-
use_input_shape (bool): Whether to use the input shape to determine the
|
222
|
-
dimension of normalization.
|
223
213
|
|
224
214
|
Returns:
|
225
215
|
The output tensor of Layer Normalization.
|
@@ -229,12 +219,6 @@ def layer_norm_with_hlfb(
|
|
229
219
|
attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
|
230
220
|
)
|
231
221
|
x, w, b = builder.mark_inputs(x, w, b)
|
232
|
-
if use_input_shape:
|
233
|
-
normalized_shape = x.shape
|
234
|
-
w = w.broadcast_to(x.shape)
|
235
|
-
b = b.broadcast_to(x.shape)
|
236
|
-
else:
|
237
|
-
normalized_shape = w.shape
|
238
222
|
y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
|
239
223
|
y = builder.mark_outputs(y)
|
240
224
|
return y
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20241006
|
3
|
+
Version: 0.3.0.dev20241007
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=G4658C8Iyg3OGjdMypYdKW18bjZ3ysuHhek8VfS6mCc,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
|
|
57
57
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
58
58
|
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
|
59
59
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
|
60
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
|
61
61
|
ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
|
62
62
|
ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
|
63
63
|
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
|
@@ -101,11 +101,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
|
|
101
101
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
102
102
|
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
103
103
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
104
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
104
|
+
ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
|
105
105
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
106
106
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
107
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
108
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
107
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
|
108
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
|
109
109
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
110
110
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
111
111
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
180
180
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
181
181
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
182
182
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
183
|
-
ai_edge_torch_nightly-0.3.0.
|
184
|
-
ai_edge_torch_nightly-0.3.0.
|
185
|
-
ai_edge_torch_nightly-0.3.0.
|
186
|
-
ai_edge_torch_nightly-0.3.0.
|
187
|
-
ai_edge_torch_nightly-0.3.0.
|
183
|
+
ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
184
|
+
ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/METADATA,sha256=GgrT0zBBoJAVXVdeYnu9-ZaI2uFn9zpd1BS4OEVPNSY,1897
|
185
|
+
ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
186
|
+
ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
187
|
+
ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/RECORD,,
|
File without changes
|
File without changes
|