ai-edge-torch-nightly 0.3.0.dev20240906__py3-none-any.whl → 0.3.0.dev20240909__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -150,6 +150,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
150
150
  # ==== Ops must be NHWC if possible
151
151
 
152
152
 
153
+ @layout_sensitive_inputs_getters.register(aten.conv2d)
153
154
  @layout_sensitive_inputs_getters.register(aten.convolution)
154
155
  @layout_sensitive_inputs_getters.register(
155
156
  aten._native_batch_norm_legit_no_training
@@ -168,6 +169,7 @@ def _first_arg_getter(node):
168
169
  @nhwcable_node_checkers.register(aten.upsample_bilinear2d)
169
170
  @nhwcable_node_checkers.register(aten.upsample_nearest2d)
170
171
  @nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
172
+ @nhwcable_node_checkers.register(aten.conv2d)
171
173
  @nhwcable_node_checkers.register(aten.convolution)
172
174
  def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
173
175
  can_be = all_layout_sensitive_inputs_are_4d(node)
@@ -229,11 +229,12 @@ def transpose_first_arg_rewriter(node: torch.fx.Node):
229
229
  node.target = nhwc_op
230
230
 
231
231
 
232
+ @rewriters.register(aten.conv2d)
232
233
  @rewriters.register(aten.convolution)
233
234
  def _aten_convolution_rewriter(node: torch.fx.Node):
234
235
  op = node.target
235
236
 
236
- def conv_nhwc(input, weight, bias, *args, **kwargs):
237
+ def conv_nhwc(input, weight, bias=None, *args, **kwargs):
237
238
  nonlocal op
238
239
  nhwc_bias = None
239
240
  if bias is not None and len(bias.shape) == 1:
@@ -145,14 +145,15 @@ class AttentionBlock2D(nn.Module):
145
145
  x = x.view(B, C, H * W)
146
146
  x = x.transpose(-1, -2)
147
147
  else:
148
- x = input_tensor.view(B, C, H * W)
149
- x = x.transpose(-1, -2)
148
+ x = torch.permute(input_tensor, (0, 2, 3, 1))
150
149
  x = self.norm(x)
150
+ x = x.view(B, H * W, C)
151
151
  x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
152
152
  x = self.attention(x)
153
- x = x.transpose(-1, -2)
154
- x = x.view(B, C, H, W)
153
+ x = x.view(B, H, W, C)
154
+ residual = torch.permute(residual, (0, 2, 3, 1))
155
155
  x = x + residual
156
+ x = torch.permute(x, (0, 3, 1, 2))
156
157
  return x
157
158
 
158
159
 
@@ -206,13 +207,14 @@ class CrossAttentionBlock2D(nn.Module):
206
207
  x = x.view(B, C, H * W)
207
208
  x = x.transpose(-1, -2)
208
209
  else:
209
- x = input_tensor.view(B, C, H * W)
210
- x = x.transpose(-1, -2)
210
+ x = torch.permute(input_tensor, (0, 2, 3, 1))
211
211
  x = self.norm(x)
212
+ x = x.view(B, H * W, C)
212
213
  x = self.attention(x, context_tensor)
213
- x = x.transpose(-1, -2)
214
- x = x.view(B, C, H, W)
214
+ x = x.view(B, H, W, C)
215
+ residual = torch.permute(residual, (0, 2, 3, 1))
215
216
  x = x + residual
217
+ x = torch.permute(x, (0, 3, 1, 2))
216
218
  return x
217
219
 
218
220
 
@@ -250,17 +252,17 @@ class FeedForwardBlock2D(nn.Module):
250
252
  x = x.view(B, C, H * W)
251
253
  x = x.transpose(-1, -2)
252
254
  else:
253
- x = input_tensor.view(B, C, H * W)
254
- x = x.transpose(-1, -2)
255
+ x = torch.permute(input_tensor, (0, 2, 3, 1))
255
256
  x = self.norm(x)
257
+ x = x.view(B, H * W, C)
256
258
  x = self.w1(x)
257
259
  x = self.act(x)
258
260
  x = self.w2(x)
259
-
260
- x = x.transpose(-1, -2) # (B, C, HW)
261
- x = x.view((B, C, H, W))
262
-
263
- return x + residual
261
+ x = x.view(B, H, W, C)
262
+ residual = torch.permute(residual, (0, 2, 3, 1))
263
+ x = x + residual
264
+ x = torch.permute(x, (0, 3, 1, 2))
265
+ return x
264
266
 
265
267
 
266
268
  class TransformerBlock2D(nn.Module):
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240906"
16
+ __version__ = "0.3.0.dev20240909"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240906
3
+ Version: 0.3.0.dev20240909
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
5
+ ai_edge_torch/version.py,sha256=r0y6crIySNGhJqtljkzyHxb1XMvLji2VLajLfUjW8b4,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=izep
16
16
  ai_edge_torch/_convert/fx_passes/canonicalize_pass.py,sha256=8jcKqWzG7p5r3Cu7DXNP-4o4X2bqLaoXY7N6W8QsZXo,1582
17
17
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=WKI8V9-V50agkiNVpBFWWp0BEpUfemdENuN1cEaGD-g,2370
18
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
19
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=DIfrWDZ1ufAN_uH-oW3k66jTciY7DlLDAb6UKMN14zE,7528
19
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
20
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
21
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=e_JWgGFOSUI9DUtmod396GNH9uJNd2VBL0DXGjbg-cE,12702
21
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
22
22
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
23
23
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=HXTDEP6_Z0I0s58H6I0yHz9qrkOxptIjKhxywfe8F80,10637
24
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -96,7 +96,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQ
96
96
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
97
97
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
98
98
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
99
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
99
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
100
100
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
101
101
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
102
102
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
161
161
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
162
162
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
163
163
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
164
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
166
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
164
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/METADATA,sha256=s7SAIUvFciy8peNKMHvyhoNQWYx67Jerz4foeV7KiE0,1859
166
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/RECORD,,