ai-edge-torch-nightly 0.3.0.dev20241006__py3-none-any.whl → 0.3.0.dev20241007__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the supported registry to which they were published. It is provided for informational purposes only.
ai_edge_torch/generative/examples/phi/phi2.py CHANGED
@@ -60,7 +60,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
   norm_config = cfg.NormalizationConfig(
       type=cfg.NormalizationType.LAYER_NORM,
-      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
   )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
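For context, a minimal sketch (not taken from the package) of what the trimmed Phi-2 normalization config amounts to: a standard layer norm over the embedding dimension, i.e. F.layer_norm with a weight-shaped normalized_shape. The dimension value below is purely illustrative.

import torch
import torch.nn.functional as F

dim = 2560                  # illustrative embedding size
x = torch.randn(1, 8, dim)  # (batch, sequence, embedding)
weight = torch.ones(dim)
bias = torch.zeros(dim)

# Normalize over the trailing `dim` elements, matching the weight shape.
y = F.layer_norm(x, (dim,), weight, bias, eps=1e-5)
assert y.shape == x.shape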
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -77,7 +77,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
     return normalization.LayerNorm(
-        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+        dim, config.epsilon, config.enable_hlfb
     )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
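A rough sketch of the simplified construction path, using torch.nn.LayerNorm as a stand-in for the package's normalization.LayerNorm (the helper name below is hypothetical):

import torch

def build_layer_norm(dim: int, epsilon: float = 1e-5) -> torch.nn.Module:
  # With use_input_shape removed, only dim and epsilon (plus the HLFB flag
  # in the real layer) are needed; normalization is always over the
  # trailing `dim` elements.
  return torch.nn.LayerNorm(dim, eps=epsilon)

norm = build_layer_norm(256)
out = norm(torch.randn(2, 16, 256))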
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -69,9 +69,6 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
-  # Whether to use the input shape to determine the dimension of normalization
-  # when type is LAYER_NORM.
-  use_input_shape: bool = True
 
 
 @dataclass
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -86,8 +86,8 @@ class GroupNorm(torch.nn.Module):
     self.enable_hlfb = enable_hlfb
     self.group_num = group_num
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.ones(dim))
-    self.bias = torch.nn.Parameter(torch.ones(dim))
+    self.weight = torch.nn.Parameter(torch.empty(dim))
+    self.bias = torch.nn.Parameter(torch.empty(dim))
 
   def forward(self, x):
     """Running the forward pass of GroupNorm layer.
@@ -117,25 +117,22 @@ class LayerNorm(torch.nn.Module):
       dim: int,
       eps: float = 1e-5,
       enable_hlfb: bool = False,
-      use_input_shape: bool = True,
   ):
     """Initialize the LayerNorm layer.
 
     Args:
       dim (int): dimension of the input tensor.
       eps (float): A small float value to ensure numerical stability (default:
-        1e-6).
+        1e-5).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
-      use_input_shape (bool): Whether to use the input shape to determine the
-        dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
-    self.use_input_shape = use_input_shape
+    self.normalized_shape = (dim,)
     self.eps = eps
-    self.weight = torch.nn.Parameter(torch.ones(dim))
-    self.bias = torch.nn.Parameter(torch.ones(dim))
+    self.weight = torch.nn.Parameter(torch.empty(dim))
+    self.bias = torch.nn.Parameter(torch.empty(dim))
 
   def forward(self, x):
     """Running the forward pass of LayerNorm layer.
@@ -148,18 +145,11 @@ class LayerNorm(torch.nn.Module):
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x, self.weight, self.bias, self.eps, self.use_input_shape
+          x, self.normalized_shape, self.weight, self.bias, self.eps
       )
-
-    if self.use_input_shape:
-      normalized_shape = x.shape
-      weight = self.weight.broadcast_to(x.shape)
-      bias = self.bias.broadcast_to(x.shape)
-    else:
-      normalized_shape = self.weight.shape
-      weight = self.weight
-      bias = self.bias
-    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
+    return F.layer_norm(
+        x, self.normalized_shape, self.weight, self.bias, self.eps
+    )
 
 
 def group_norm_with_hlfb(
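A quick, self-contained check (illustrative, not part of the package) that the new normalized_shape=(dim,) path matches the removed use_input_shape=False branch, which normalized over the weight shape:

import torch
import torch.nn.functional as F

dim = 64
x = torch.randn(2, 10, dim)
weight = torch.randn(dim)
bias = torch.randn(dim)
eps = 1e-5

old = F.layer_norm(x, weight.shape, weight, bias, eps)  # weight-shape branch
new = F.layer_norm(x, (dim,), weight, bias, eps)        # explicit (dim,)
assert torch.allclose(old, new)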
@@ -206,20 +196,20 @@ def group_norm_with_hlfb(
 
 def layer_norm_with_hlfb(
     x: torch.Tensor,
+    normalized_shape: list[int],
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
-    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.
 
   Args:
     x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
+    normalized_shape (list[int]): Input shape from an expected input of size,
+      same as https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
-    use_input_shape (bool): Whether to use the input shape to determine the
-      dimension of normalization.
 
   Returns:
     The output tensor of Layer Normalization.
@@ -229,12 +219,6 @@ def layer_norm_with_hlfb(
       attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
   )
   x, w, b = builder.mark_inputs(x, w, b)
-  if use_input_shape:
-    normalized_shape = x.shape
-    w = w.broadcast_to(x.shape)
-    b = b.broadcast_to(x.shape)
-  else:
-    normalized_shape = w.shape
   y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
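For reference, a plain-PyTorch sketch of the updated calling convention; the real function additionally marks the op as a StableHLO composite via the package's builder, which is omitted here, and the wrapper name below is hypothetical:

import torch
import torch.nn.functional as F

def layer_norm_plain(x, normalized_shape, w, b, eps):
  # Same math as layer_norm_with_hlfb after this change, minus the
  # builder.mark_inputs / builder.mark_outputs composite marking.
  return F.layer_norm(x, normalized_shape, w, b, eps=eps)

dim = 32
y = layer_norm_plain(
    torch.randn(1, 4, dim), (dim,), torch.ones(dim), torch.zeros(dim), 1e-5
)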
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241006"
+__version__ = "0.3.0.dev20241007"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241006
+Version: 0.3.0.dev20241007
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=VXOiPAspd6Q0ZIwnjLplRQtIy97iE7lFvqOIWCFNigI,706
+ai_edge_torch/version.py,sha256=G4658C8Iyg3OGjdMypYdKW18bjZ3ysuHhek8VfS6mCc,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=CQ55KfOdoOM43CxF7yNQsgq8b-j0S50bXpxYzgq-keM,3418
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
 ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
 ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
 ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
@@ -101,11 +101,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
+ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=xZt4xaNZJPvtdy4hfbnRencEENr689zO0WnZbhpNTIs,7137
-ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
+ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
+ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/METADATA,sha256=WCkiyqsYONZTjrZ5OS0O3V4jK8kHc9MHDI6iJiywO9k,1897
-ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/METADATA,sha256=GgrT0zBBoJAVXVdeYnu9-ZaI2uFn9zpd1BS4OEVPNSY,1897
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/RECORD,,