ai-edge-torch-nightly 0.3.0.dev20240906__py3-none-any.whl → 0.3.0.dev20240908__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +2 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/layers/unet/blocks_2d.py +17 -15
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/top_level.txt +0 -0
@@ -150,6 +150,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
|
|
150
150
|
# ==== Ops must be NHWC if possible
|
151
151
|
|
152
152
|
|
153
|
+
@layout_sensitive_inputs_getters.register(aten.conv2d)
|
153
154
|
@layout_sensitive_inputs_getters.register(aten.convolution)
|
154
155
|
@layout_sensitive_inputs_getters.register(
|
155
156
|
aten._native_batch_norm_legit_no_training
|
@@ -168,6 +169,7 @@ def _first_arg_getter(node):
|
|
168
169
|
@nhwcable_node_checkers.register(aten.upsample_bilinear2d)
|
169
170
|
@nhwcable_node_checkers.register(aten.upsample_nearest2d)
|
170
171
|
@nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
|
172
|
+
@nhwcable_node_checkers.register(aten.conv2d)
|
171
173
|
@nhwcable_node_checkers.register(aten.convolution)
|
172
174
|
def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
|
173
175
|
can_be = all_layout_sensitive_inputs_are_4d(node)
|
@@ -229,11 +229,12 @@ def transpose_first_arg_rewriter(node: torch.fx.Node):
|
|
229
229
|
node.target = nhwc_op
|
230
230
|
|
231
231
|
|
232
|
+
@rewriters.register(aten.conv2d)
|
232
233
|
@rewriters.register(aten.convolution)
|
233
234
|
def _aten_convolution_rewriter(node: torch.fx.Node):
|
234
235
|
op = node.target
|
235
236
|
|
236
|
-
def conv_nhwc(input, weight, bias, *args, **kwargs):
|
237
|
+
def conv_nhwc(input, weight, bias=None, *args, **kwargs):
|
237
238
|
nonlocal op
|
238
239
|
nhwc_bias = None
|
239
240
|
if bias is not None and len(bias.shape) == 1:
|
@@ -145,14 +145,15 @@ class AttentionBlock2D(nn.Module):
|
|
145
145
|
x = x.view(B, C, H * W)
|
146
146
|
x = x.transpose(-1, -2)
|
147
147
|
else:
|
148
|
-
x =
|
149
|
-
x = x.transpose(-1, -2)
|
148
|
+
x = torch.permute(input_tensor, (0, 2, 3, 1))
|
150
149
|
x = self.norm(x)
|
150
|
+
x = x.view(B, H * W, C)
|
151
151
|
x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
|
152
152
|
x = self.attention(x)
|
153
|
-
x = x.
|
154
|
-
|
153
|
+
x = x.view(B, H, W, C)
|
154
|
+
residual = torch.permute(residual, (0, 2, 3, 1))
|
155
155
|
x = x + residual
|
156
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
156
157
|
return x
|
157
158
|
|
158
159
|
|
@@ -206,13 +207,14 @@ class CrossAttentionBlock2D(nn.Module):
|
|
206
207
|
x = x.view(B, C, H * W)
|
207
208
|
x = x.transpose(-1, -2)
|
208
209
|
else:
|
209
|
-
x =
|
210
|
-
x = x.transpose(-1, -2)
|
210
|
+
x = torch.permute(input_tensor, (0, 2, 3, 1))
|
211
211
|
x = self.norm(x)
|
212
|
+
x = x.view(B, H * W, C)
|
212
213
|
x = self.attention(x, context_tensor)
|
213
|
-
x = x.
|
214
|
-
|
214
|
+
x = x.view(B, H, W, C)
|
215
|
+
residual = torch.permute(residual, (0, 2, 3, 1))
|
215
216
|
x = x + residual
|
217
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
216
218
|
return x
|
217
219
|
|
218
220
|
|
@@ -250,17 +252,17 @@ class FeedForwardBlock2D(nn.Module):
|
|
250
252
|
x = x.view(B, C, H * W)
|
251
253
|
x = x.transpose(-1, -2)
|
252
254
|
else:
|
253
|
-
x =
|
254
|
-
x = x.transpose(-1, -2)
|
255
|
+
x = torch.permute(input_tensor, (0, 2, 3, 1))
|
255
256
|
x = self.norm(x)
|
257
|
+
x = x.view(B, H * W, C)
|
256
258
|
x = self.w1(x)
|
257
259
|
x = self.act(x)
|
258
260
|
x = self.w2(x)
|
259
|
-
|
260
|
-
|
261
|
-
x = x
|
262
|
-
|
263
|
-
return x
|
261
|
+
x = x.view(B, H, W, C)
|
262
|
+
residual = torch.permute(residual, (0, 2, 3, 1))
|
263
|
+
x = x + residual
|
264
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
265
|
+
return x
|
264
266
|
|
265
267
|
|
266
268
|
class TransformerBlock2D(nn.Module):
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20240906
|
3
|
+
Version: 0.3.0.dev20240908
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=gyyuDD-i83EGkYIzPuqIjPUIC8huFW09RDInbOSOx1c,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=izep
|
|
16
16
|
ai_edge_torch/_convert/fx_passes/canonicalize_pass.py,sha256=8jcKqWzG7p5r3Cu7DXNP-4o4X2bqLaoXY7N6W8QsZXo,1582
|
17
17
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=WKI8V9-V50agkiNVpBFWWp0BEpUfemdENuN1cEaGD-g,2370
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
19
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
19
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
|
22
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
|
23
23
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=HXTDEP6_Z0I0s58H6I0yHz9qrkOxptIjKhxywfe8F80,10637
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
@@ -96,7 +96,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQ
|
|
96
96
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
97
97
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
98
98
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
99
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
99
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
|
100
100
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
101
101
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
|
102
102
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
161
161
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
162
162
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
163
163
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
164
|
-
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
-
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=
|
166
|
-
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
167
|
-
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
168
|
-
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/METADATA,sha256=kDesFhNNQZ4QCG572z3cDo3eyiSnnFTUMwzKOmdsrGo,1859
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/RECORD,,
|
File without changes
|
File without changes
|