ai-edge-torch-nightly 0.3.0.dev20241006__py3-none-any.whl → 0.3.0.dev20241007__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -60,7 +60,6 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
60
60
  )
61
61
  norm_config = cfg.NormalizationConfig(
62
62
  type=cfg.NormalizationType.LAYER_NORM,
63
- use_input_shape=False, # Phi-2 does layer-norm with the weight shape.
64
63
  )
65
64
  block_config = cfg.TransformerBlockConfig(
66
65
  attn_config=attn_config,
@@ -77,7 +77,7 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
77
77
  )
78
78
  elif config.type == cfg.NormalizationType.LAYER_NORM:
79
79
  return normalization.LayerNorm(
80
- dim, config.epsilon, config.enable_hlfb, config.use_input_shape
80
+ dim, config.epsilon, config.enable_hlfb
81
81
  )
82
82
  elif config.type == cfg.NormalizationType.GROUP_NORM:
83
83
  return normalization.GroupNorm(
@@ -69,9 +69,6 @@ class NormalizationConfig:
69
69
  zero_centered: bool = False
70
70
  # Number of groups used in group normalization.
71
71
  group_num: Optional[float] = None
72
- # Whether to use the input shape to determine the dimension of normalization
73
- # when type is LAYER_NORM.
74
- use_input_shape: bool = True
75
72
 
76
73
 
77
74
  @dataclass
@@ -86,8 +86,8 @@ class GroupNorm(torch.nn.Module):
86
86
  self.enable_hlfb = enable_hlfb
87
87
  self.group_num = group_num
88
88
  self.eps = eps
89
- self.weight = torch.nn.Parameter(torch.ones(dim))
90
- self.bias = torch.nn.Parameter(torch.ones(dim))
89
+ self.weight = torch.nn.Parameter(torch.empty(dim))
90
+ self.bias = torch.nn.Parameter(torch.empty(dim))
91
91
 
92
92
  def forward(self, x):
93
93
  """Running the forward pass of GroupNorm layer.
@@ -117,25 +117,22 @@ class LayerNorm(torch.nn.Module):
117
117
  dim: int,
118
118
  eps: float = 1e-5,
119
119
  enable_hlfb: bool = False,
120
- use_input_shape: bool = True,
121
120
  ):
122
121
  """Initialize the LayerNorm layer.
123
122
 
124
123
  Args:
125
124
  dim (int): dimension of the input tensor.
126
125
  eps (float): A small float value to ensure numerical stability (default:
127
- 1e-6).
126
+ 1e-5).
128
127
  enable_hlfb (bool): Whether to convert this normalization into a single
129
128
  op.
130
- use_input_shape (bool): Whether to use the input shape to determine the
131
- dimension of normalization (default: True).
132
129
  """
133
130
  super().__init__()
134
131
  self.enable_hlfb = enable_hlfb
135
- self.use_input_shape = use_input_shape
132
+ self.normalized_shape = (dim,)
136
133
  self.eps = eps
137
- self.weight = torch.nn.Parameter(torch.ones(dim))
138
- self.bias = torch.nn.Parameter(torch.ones(dim))
134
+ self.weight = torch.nn.Parameter(torch.empty(dim))
135
+ self.bias = torch.nn.Parameter(torch.empty(dim))
139
136
 
140
137
  def forward(self, x):
141
138
  """Running the forward pass of LayerNorm layer.
@@ -148,18 +145,11 @@ class LayerNorm(torch.nn.Module):
148
145
  """
149
146
  if self.enable_hlfb:
150
147
  return layer_norm_with_hlfb(
151
- x, self.weight, self.bias, self.eps, self.use_input_shape
148
+ x, self.normalized_shape, self.weight, self.bias, self.eps
152
149
  )
153
-
154
- if self.use_input_shape:
155
- normalized_shape = x.shape
156
- weight = self.weight.broadcast_to(x.shape)
157
- bias = self.bias.broadcast_to(x.shape)
158
- else:
159
- normalized_shape = self.weight.shape
160
- weight = self.weight
161
- bias = self.bias
162
- return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
150
+ return F.layer_norm(
151
+ x, self.normalized_shape, self.weight, self.bias, self.eps
152
+ )
163
153
 
164
154
 
165
155
  def group_norm_with_hlfb(
@@ -206,20 +196,20 @@ def group_norm_with_hlfb(
206
196
 
207
197
  def layer_norm_with_hlfb(
208
198
  x: torch.Tensor,
199
+ normalized_shape: list[int],
209
200
  w: torch.Tensor,
210
201
  b: torch.Tensor,
211
202
  eps: float,
212
- use_input_shape: bool,
213
203
  ):
214
204
  """Layer Normalization with high-level function boundary enabled.
215
205
 
216
206
  Args:
217
207
  x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
208
+ normalized_shape (list[int]): Input shape from an expected input of size,
209
+ same as https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html.
218
210
  w (torch.Tensor): The weight tensor for the normalization.
219
211
  b (torch.Tensor): The bias tensor for the normalization.
220
212
  eps (float): A small float value to ensure numerical stability.
221
- use_input_shape (bool): Whether to use the input shape to determine the
222
- dimension of normalization.
223
213
 
224
214
  Returns:
225
215
  The output tensor of Layer Normalization.
@@ -229,12 +219,6 @@ def layer_norm_with_hlfb(
229
219
  attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
230
220
  )
231
221
  x, w, b = builder.mark_inputs(x, w, b)
232
- if use_input_shape:
233
- normalized_shape = x.shape
234
- w = w.broadcast_to(x.shape)
235
- b = b.broadcast_to(x.shape)
236
- else:
237
- normalized_shape = w.shape
238
222
  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
239
223
  y = builder.mark_outputs(y)
240
224
  return y
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241006"
16
+ __version__ = "0.3.0.dev20241007"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241006
3
+ Version: 0.3.0.dev20241007
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=VXOiPAspd6Q0ZIwnjLplRQtIy97iE7lFvqOIWCFNigI,706
6
+ ai_edge_torch/version.py,sha256=G4658C8Iyg3OGjdMypYdKW18bjZ3ysuHhek8VfS6mCc,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3Tif
57
57
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
58
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
59
59
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
60
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=CQ55KfOdoOM43CxF7yNQsgq8b-j0S50bXpxYzgq-keM,3418
60
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
61
61
  ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
62
62
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
63
63
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
@@ -101,11 +101,11 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
101
101
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
102
102
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
103
103
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
104
- ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
104
+ ai_edge_torch/generative/layers/builder.py,sha256=XyZS1RrnMbvypeLMfwU7h1Y4x5r4WGgOx2YGJF0OUNQ,5064
105
105
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
106
106
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
107
- ai_edge_torch/generative/layers/model_config.py,sha256=xZt4xaNZJPvtdy4hfbnRencEENr689zO0WnZbhpNTIs,7137
108
- ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
107
+ ai_edge_torch/generative/layers/model_config.py,sha256=DdsdhTP5tZAtyWim-qW2m8HDBsYbs7boqSDb83vwgmE,6998
108
+ ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
109
109
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
110
110
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
111
111
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
180
180
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
181
181
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
182
182
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
183
- ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/METADATA,sha256=WCkiyqsYONZTjrZ5OS0O3V4jK8kHc9MHDI6iJiywO9k,1897
185
- ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
- ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
- ai_edge_torch_nightly-0.3.0.dev20241006.dist-info/RECORD,,
183
+ ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
+ ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/METADATA,sha256=GgrT0zBBoJAVXVdeYnu9-ZaI2uFn9zpd1BS4OEVPNSY,1897
185
+ ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
+ ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
+ ai_edge_torch_nightly-0.3.0.dev20241007.dist-info/RECORD,,