ai-edge-torch-nightly 0.3.0.dev20240906__py3-none-any.whl → 0.3.0.dev20240909__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -150,6 +150,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
150
150
  # ==== Ops must be NHWC if possible
151
151
 
152
152
 
153
+ @layout_sensitive_inputs_getters.register(aten.conv2d)
153
154
  @layout_sensitive_inputs_getters.register(aten.convolution)
154
155
  @layout_sensitive_inputs_getters.register(
155
156
  aten._native_batch_norm_legit_no_training
@@ -168,6 +169,7 @@ def _first_arg_getter(node):
168
169
  @nhwcable_node_checkers.register(aten.upsample_bilinear2d)
169
170
  @nhwcable_node_checkers.register(aten.upsample_nearest2d)
170
171
  @nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
172
+ @nhwcable_node_checkers.register(aten.conv2d)
171
173
  @nhwcable_node_checkers.register(aten.convolution)
172
174
  def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
173
175
  can_be = all_layout_sensitive_inputs_are_4d(node)
@@ -229,11 +229,12 @@ def transpose_first_arg_rewriter(node: torch.fx.Node):
229
229
  node.target = nhwc_op
230
230
 
231
231
 
232
+ @rewriters.register(aten.conv2d)
232
233
  @rewriters.register(aten.convolution)
233
234
  def _aten_convolution_rewriter(node: torch.fx.Node):
234
235
  op = node.target
235
236
 
236
- def conv_nhwc(input, weight, bias, *args, **kwargs):
237
+ def conv_nhwc(input, weight, bias=None, *args, **kwargs):
237
238
  nonlocal op
238
239
  nhwc_bias = None
239
240
  if bias is not None and len(bias.shape) == 1:
@@ -145,14 +145,15 @@ class AttentionBlock2D(nn.Module):
145
145
  x = x.view(B, C, H * W)
146
146
  x = x.transpose(-1, -2)
147
147
  else:
148
- x = input_tensor.view(B, C, H * W)
149
- x = x.transpose(-1, -2)
148
+ x = torch.permute(input_tensor, (0, 2, 3, 1))
150
149
  x = self.norm(x)
150
+ x = x.view(B, H * W, C)
151
151
  x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
152
152
  x = self.attention(x)
153
- x = x.transpose(-1, -2)
154
- x = x.view(B, C, H, W)
153
+ x = x.view(B, H, W, C)
154
+ residual = torch.permute(residual, (0, 2, 3, 1))
155
155
  x = x + residual
156
+ x = torch.permute(x, (0, 3, 1, 2))
156
157
  return x
157
158
 
158
159
 
@@ -206,13 +207,14 @@ class CrossAttentionBlock2D(nn.Module):
206
207
  x = x.view(B, C, H * W)
207
208
  x = x.transpose(-1, -2)
208
209
  else:
209
- x = input_tensor.view(B, C, H * W)
210
- x = x.transpose(-1, -2)
210
+ x = torch.permute(input_tensor, (0, 2, 3, 1))
211
211
  x = self.norm(x)
212
+ x = x.view(B, H * W, C)
212
213
  x = self.attention(x, context_tensor)
213
- x = x.transpose(-1, -2)
214
- x = x.view(B, C, H, W)
214
+ x = x.view(B, H, W, C)
215
+ residual = torch.permute(residual, (0, 2, 3, 1))
215
216
  x = x + residual
217
+ x = torch.permute(x, (0, 3, 1, 2))
216
218
  return x
217
219
 
218
220
 
@@ -250,17 +252,17 @@ class FeedForwardBlock2D(nn.Module):
250
252
  x = x.view(B, C, H * W)
251
253
  x = x.transpose(-1, -2)
252
254
  else:
253
- x = input_tensor.view(B, C, H * W)
254
- x = x.transpose(-1, -2)
255
+ x = torch.permute(input_tensor, (0, 2, 3, 1))
255
256
  x = self.norm(x)
257
+ x = x.view(B, H * W, C)
256
258
  x = self.w1(x)
257
259
  x = self.act(x)
258
260
  x = self.w2(x)
259
-
260
- x = x.transpose(-1, -2) # (B, C, HW)
261
- x = x.view((B, C, H, W))
262
-
263
- return x + residual
261
+ x = x.view(B, H, W, C)
262
+ residual = torch.permute(residual, (0, 2, 3, 1))
263
+ x = x + residual
264
+ x = torch.permute(x, (0, 3, 1, 2))
265
+ return x
264
266
 
265
267
 
266
268
  class TransformerBlock2D(nn.Module):
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240906"
16
+ __version__ = "0.3.0.dev20240909"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240906
3
+ Version: 0.3.0.dev20240909
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
5
+ ai_edge_torch/version.py,sha256=r0y6crIySNGhJqtljkzyHxb1XMvLji2VLajLfUjW8b4,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=izep
16
16
  ai_edge_torch/_convert/fx_passes/canonicalize_pass.py,sha256=8jcKqWzG7p5r3Cu7DXNP-4o4X2bqLaoXY7N6W8QsZXo,1582
17
17
  ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=WKI8V9-V50agkiNVpBFWWp0BEpUfemdENuN1cEaGD-g,2370
18
18
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
19
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=DIfrWDZ1ufAN_uH-oW3k66jTciY7DlLDAb6UKMN14zE,7528
19
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
20
20
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
21
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=e_JWgGFOSUI9DUtmod396GNH9uJNd2VBL0DXGjbg-cE,12702
21
+ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
22
22
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
23
23
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=HXTDEP6_Z0I0s58H6I0yHz9qrkOxptIjKhxywfe8F80,10637
24
24
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -96,7 +96,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQ
96
96
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
97
97
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
98
98
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
99
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
99
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
100
100
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
101
101
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
102
102
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
161
161
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
162
162
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
163
163
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
164
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
166
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
- ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
164
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/METADATA,sha256=s7SAIUvFciy8peNKMHvyhoNQWYx67Jerz4foeV7KiE0,1859
166
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
+ ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/RECORD,,