ai-edge-torch-nightly 0.3.0.dev20240906__py3-none-any.whl → 0.3.0.dev20240908__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +2 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/layers/unet/blocks_2d.py +17 -15
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240908.dist-info}/top_level.txt +0 -0
@@ -150,6 +150,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
|
|
150
150
|
# ==== Ops must be NHWC if possible
|
151
151
|
|
152
152
|
|
153
|
+
@layout_sensitive_inputs_getters.register(aten.conv2d)
|
153
154
|
@layout_sensitive_inputs_getters.register(aten.convolution)
|
154
155
|
@layout_sensitive_inputs_getters.register(
|
155
156
|
aten._native_batch_norm_legit_no_training
|
@@ -168,6 +169,7 @@ def _first_arg_getter(node):
|
|
168
169
|
@nhwcable_node_checkers.register(aten.upsample_bilinear2d)
|
169
170
|
@nhwcable_node_checkers.register(aten.upsample_nearest2d)
|
170
171
|
@nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
|
172
|
+
@nhwcable_node_checkers.register(aten.conv2d)
|
171
173
|
@nhwcable_node_checkers.register(aten.convolution)
|
172
174
|
def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
|
173
175
|
can_be = all_layout_sensitive_inputs_are_4d(node)
|
@@ -229,11 +229,12 @@ def transpose_first_arg_rewriter(node: torch.fx.Node):
|
|
229
229
|
node.target = nhwc_op
|
230
230
|
|
231
231
|
|
232
|
+
@rewriters.register(aten.conv2d)
|
232
233
|
@rewriters.register(aten.convolution)
|
233
234
|
def _aten_convolution_rewriter(node: torch.fx.Node):
|
234
235
|
op = node.target
|
235
236
|
|
236
|
-
def conv_nhwc(input, weight, bias, *args, **kwargs):
|
237
|
+
def conv_nhwc(input, weight, bias=None, *args, **kwargs):
|
237
238
|
nonlocal op
|
238
239
|
nhwc_bias = None
|
239
240
|
if bias is not None and len(bias.shape) == 1:
|
@@ -145,14 +145,15 @@ class AttentionBlock2D(nn.Module):
|
|
145
145
|
x = x.view(B, C, H * W)
|
146
146
|
x = x.transpose(-1, -2)
|
147
147
|
else:
|
148
|
-
x =
|
149
|
-
x = x.transpose(-1, -2)
|
148
|
+
x = torch.permute(input_tensor, (0, 2, 3, 1))
|
150
149
|
x = self.norm(x)
|
150
|
+
x = x.view(B, H * W, C)
|
151
151
|
x = x.contiguous() # Prevent BATCH_MATMUL op in converted tflite.
|
152
152
|
x = self.attention(x)
|
153
|
-
x = x.
|
154
|
-
|
153
|
+
x = x.view(B, H, W, C)
|
154
|
+
residual = torch.permute(residual, (0, 2, 3, 1))
|
155
155
|
x = x + residual
|
156
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
156
157
|
return x
|
157
158
|
|
158
159
|
|
@@ -206,13 +207,14 @@ class CrossAttentionBlock2D(nn.Module):
|
|
206
207
|
x = x.view(B, C, H * W)
|
207
208
|
x = x.transpose(-1, -2)
|
208
209
|
else:
|
209
|
-
x =
|
210
|
-
x = x.transpose(-1, -2)
|
210
|
+
x = torch.permute(input_tensor, (0, 2, 3, 1))
|
211
211
|
x = self.norm(x)
|
212
|
+
x = x.view(B, H * W, C)
|
212
213
|
x = self.attention(x, context_tensor)
|
213
|
-
x = x.
|
214
|
-
|
214
|
+
x = x.view(B, H, W, C)
|
215
|
+
residual = torch.permute(residual, (0, 2, 3, 1))
|
215
216
|
x = x + residual
|
217
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
216
218
|
return x
|
217
219
|
|
218
220
|
|
@@ -250,17 +252,17 @@ class FeedForwardBlock2D(nn.Module):
|
|
250
252
|
x = x.view(B, C, H * W)
|
251
253
|
x = x.transpose(-1, -2)
|
252
254
|
else:
|
253
|
-
x =
|
254
|
-
x = x.transpose(-1, -2)
|
255
|
+
x = torch.permute(input_tensor, (0, 2, 3, 1))
|
255
256
|
x = self.norm(x)
|
257
|
+
x = x.view(B, H * W, C)
|
256
258
|
x = self.w1(x)
|
257
259
|
x = self.act(x)
|
258
260
|
x = self.w2(x)
|
259
|
-
|
260
|
-
|
261
|
-
x = x
|
262
|
-
|
263
|
-
return x
|
261
|
+
x = x.view(B, H, W, C)
|
262
|
+
residual = torch.permute(residual, (0, 2, 3, 1))
|
263
|
+
x = x + residual
|
264
|
+
x = torch.permute(x, (0, 3, 1, 2))
|
265
|
+
return x
|
264
266
|
|
265
267
|
|
266
268
|
class TransformerBlock2D(nn.Module):
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240908
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=gyyuDD-i83EGkYIzPuqIjPUIC8huFW09RDInbOSOx1c,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=izep
|
|
16
16
|
ai_edge_torch/_convert/fx_passes/canonicalize_pass.py,sha256=8jcKqWzG7p5r3Cu7DXNP-4o4X2bqLaoXY7N6W8QsZXo,1582
|
17
17
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=WKI8V9-V50agkiNVpBFWWp0BEpUfemdENuN1cEaGD-g,2370
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
19
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
19
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
|
22
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
|
23
23
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=HXTDEP6_Z0I0s58H6I0yHz9qrkOxptIjKhxywfe8F80,10637
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
@@ -96,7 +96,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQ
|
|
96
96
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
97
97
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
98
98
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
99
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
99
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
|
100
100
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
101
101
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
|
102
102
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
161
161
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
162
162
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
163
163
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
165
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
-
ai_edge_torch_nightly-0.3.0.
|
167
|
-
ai_edge_torch_nightly-0.3.0.
|
168
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/METADATA,sha256=kDesFhNNQZ4QCG572z3cDo3eyiSnnFTUMwzKOmdsrGo,1859
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240908.dist-info/RECORD,,
|
File without changes
|
File without changes
|