ai-edge-torch-nightly 0.3.0.dev20240906__py3-none-any.whl → 0.3.0.dev20240909__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +2 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +2 -1
- ai_edge_torch/generative/layers/unet/blocks_2d.py +17 -15
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240909.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240909.dist-info}/RECORD +9 -9
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240909.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240909.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240906.dist-info → ai_edge_torch_nightly-0.3.0.dev20240909.dist-info}/top_level.txt +0 -0
| @@ -150,6 +150,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node): | |
| 150 150 | 
             
            # ==== Ops must be NHWC if possible
         | 
| 151 151 |  | 
| 152 152 |  | 
| 153 | 
            +
            @layout_sensitive_inputs_getters.register(aten.conv2d)
         | 
| 153 154 | 
             
            @layout_sensitive_inputs_getters.register(aten.convolution)
         | 
| 154 155 | 
             
            @layout_sensitive_inputs_getters.register(
         | 
| 155 156 | 
             
                aten._native_batch_norm_legit_no_training
         | 
| @@ -168,6 +169,7 @@ def _first_arg_getter(node): | |
| 168 169 | 
             
            @nhwcable_node_checkers.register(aten.upsample_bilinear2d)
         | 
| 169 170 | 
             
            @nhwcable_node_checkers.register(aten.upsample_nearest2d)
         | 
| 170 171 | 
             
            @nhwcable_node_checkers.register(aten._adaptive_avg_pool2d)
         | 
| 172 | 
            +
            @nhwcable_node_checkers.register(aten.conv2d)
         | 
| 171 173 | 
             
            @nhwcable_node_checkers.register(aten.convolution)
         | 
| 172 174 | 
             
            def _all_layout_sensitive_inputs_are_4d_checker(node: Node):
         | 
| 173 175 | 
             
              can_be = all_layout_sensitive_inputs_are_4d(node)
         | 
| @@ -229,11 +229,12 @@ def transpose_first_arg_rewriter(node: torch.fx.Node): | |
| 229 229 | 
             
              node.target = nhwc_op
         | 
| 230 230 |  | 
| 231 231 |  | 
| 232 | 
            +
            @rewriters.register(aten.conv2d)
         | 
| 232 233 | 
             
            @rewriters.register(aten.convolution)
         | 
| 233 234 | 
             
            def _aten_convolution_rewriter(node: torch.fx.Node):
         | 
| 234 235 | 
             
              op = node.target
         | 
| 235 236 |  | 
| 236 | 
            -
              def conv_nhwc(input, weight, bias, *args, **kwargs):
         | 
| 237 | 
            +
              def conv_nhwc(input, weight, bias=None, *args, **kwargs):
         | 
| 237 238 | 
             
                nonlocal op
         | 
| 238 239 | 
             
                nhwc_bias = None
         | 
| 239 240 | 
             
                if bias is not None and len(bias.shape) == 1:
         | 
| @@ -145,14 +145,15 @@ class AttentionBlock2D(nn.Module): | |
| 145 145 | 
             
                  x = x.view(B, C, H * W)
         | 
| 146 146 | 
             
                  x = x.transpose(-1, -2)
         | 
| 147 147 | 
             
                else:
         | 
| 148 | 
            -
                  x =  | 
| 149 | 
            -
                  x = x.transpose(-1, -2)
         | 
| 148 | 
            +
                  x = torch.permute(input_tensor, (0, 2, 3, 1))
         | 
| 150 149 | 
             
                  x = self.norm(x)
         | 
| 150 | 
            +
                  x = x.view(B, H * W, C)
         | 
| 151 151 | 
             
                x = x.contiguous()  # Prevent BATCH_MATMUL op in converted tflite.
         | 
| 152 152 | 
             
                x = self.attention(x)
         | 
| 153 | 
            -
                x = x. | 
| 154 | 
            -
                 | 
| 153 | 
            +
                x = x.view(B, H, W, C)
         | 
| 154 | 
            +
                residual = torch.permute(residual, (0, 2, 3, 1))
         | 
| 155 155 | 
             
                x = x + residual
         | 
| 156 | 
            +
                x = torch.permute(x, (0, 3, 1, 2))
         | 
| 156 157 | 
             
                return x
         | 
| 157 158 |  | 
| 158 159 |  | 
| @@ -206,13 +207,14 @@ class CrossAttentionBlock2D(nn.Module): | |
| 206 207 | 
             
                  x = x.view(B, C, H * W)
         | 
| 207 208 | 
             
                  x = x.transpose(-1, -2)
         | 
| 208 209 | 
             
                else:
         | 
| 209 | 
            -
                  x =  | 
| 210 | 
            -
                  x = x.transpose(-1, -2)
         | 
| 210 | 
            +
                  x = torch.permute(input_tensor, (0, 2, 3, 1))
         | 
| 211 211 | 
             
                  x = self.norm(x)
         | 
| 212 | 
            +
                  x = x.view(B, H * W, C)
         | 
| 212 213 | 
             
                x = self.attention(x, context_tensor)
         | 
| 213 | 
            -
                x = x. | 
| 214 | 
            -
                 | 
| 214 | 
            +
                x = x.view(B, H, W, C)
         | 
| 215 | 
            +
                residual = torch.permute(residual, (0, 2, 3, 1))
         | 
| 215 216 | 
             
                x = x + residual
         | 
| 217 | 
            +
                x = torch.permute(x, (0, 3, 1, 2))
         | 
| 216 218 | 
             
                return x
         | 
| 217 219 |  | 
| 218 220 |  | 
| @@ -250,17 +252,17 @@ class FeedForwardBlock2D(nn.Module): | |
| 250 252 | 
             
                  x = x.view(B, C, H * W)
         | 
| 251 253 | 
             
                  x = x.transpose(-1, -2)
         | 
| 252 254 | 
             
                else:
         | 
| 253 | 
            -
                  x =  | 
| 254 | 
            -
                  x = x.transpose(-1, -2)
         | 
| 255 | 
            +
                  x = torch.permute(input_tensor, (0, 2, 3, 1))
         | 
| 255 256 | 
             
                  x = self.norm(x)
         | 
| 257 | 
            +
                  x = x.view(B, H * W, C)
         | 
| 256 258 | 
             
                x = self.w1(x)
         | 
| 257 259 | 
             
                x = self.act(x)
         | 
| 258 260 | 
             
                x = self.w2(x)
         | 
| 259 | 
            -
             | 
| 260 | 
            -
                 | 
| 261 | 
            -
                x = x | 
| 262 | 
            -
             | 
| 263 | 
            -
                return x | 
| 261 | 
            +
                x = x.view(B, H, W, C)
         | 
| 262 | 
            +
                residual = torch.permute(residual, (0, 2, 3, 1))
         | 
| 263 | 
            +
                x = x + residual
         | 
| 264 | 
            +
                x = torch.permute(x, (0, 3, 1, 2))
         | 
| 265 | 
            +
                return x
         | 
| 264 266 |  | 
| 265 267 |  | 
| 266 268 | 
             
            class TransformerBlock2D(nn.Module):
         | 
    
        ai_edge_torch/version.py
    CHANGED
    
    
| @@ -1,6 +1,6 @@ | |
| 1 1 | 
             
            Metadata-Version: 2.1
         | 
| 2 2 | 
             
            Name: ai-edge-torch-nightly
         | 
| 3 | 
            -
            Version: 0.3.0. | 
| 3 | 
            +
            Version: 0.3.0.dev20240909
         | 
| 4 4 | 
             
            Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
         | 
| 5 5 | 
             
            Home-page: https://github.com/google-ai-edge/ai-edge-torch
         | 
| 6 6 | 
             
            Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
         | 
| @@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116 | |
| 2 2 | 
             
            ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
         | 
| 3 3 | 
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         | 
| 4 4 | 
             
            ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
         | 
| 5 | 
            -
            ai_edge_torch/version.py,sha256= | 
| 5 | 
            +
            ai_edge_torch/version.py,sha256=r0y6crIySNGhJqtljkzyHxb1XMvLji2VLajLfUjW8b4,706
         | 
| 6 6 | 
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 7 7 | 
             
            ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
         | 
| 8 8 | 
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         | 
| @@ -16,9 +16,9 @@ ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=izep | |
| 16 16 | 
             
            ai_edge_torch/_convert/fx_passes/canonicalize_pass.py,sha256=8jcKqWzG7p5r3Cu7DXNP-4o4X2bqLaoXY7N6W8QsZXo,1582
         | 
| 17 17 | 
             
            ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=WKI8V9-V50agkiNVpBFWWp0BEpUfemdENuN1cEaGD-g,2370
         | 
| 18 18 | 
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
         | 
| 19 | 
            -
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256= | 
| 19 | 
            +
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=hDsl9AHzmyuSWsdHOSO114l4nBUgUdAOUWafMTipMgA,7629
         | 
| 20 20 | 
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
         | 
| 21 | 
            -
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256= | 
| 21 | 
            +
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=NW37V6QYdPOZOVhqLcmssVk-VAeO4ECk_CrbEBh4B0E,12740
         | 
| 22 22 | 
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
         | 
| 23 23 | 
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=HXTDEP6_Z0I0s58H6I0yHz9qrkOxptIjKhxywfe8F80,10637
         | 
| 24 24 | 
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
         | 
| @@ -96,7 +96,7 @@ ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQ | |
| 96 96 | 
             
            ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
         | 
| 97 97 | 
             
            ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
         | 
| 98 98 | 
             
            ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 99 | 
            -
            ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256= | 
| 99 | 
            +
            ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
         | 
| 100 100 | 
             
            ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
         | 
| 101 101 | 
             
            ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
         | 
| 102 102 | 
             
            ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| @@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9 | |
| 161 161 | 
             
            ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         | 
| 162 162 | 
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         | 
| 163 163 | 
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         | 
| 164 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 165 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 166 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 167 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 168 | 
            -
            ai_edge_torch_nightly-0.3.0. | 
| 164 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         | 
| 165 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/METADATA,sha256=s7SAIUvFciy8peNKMHvyhoNQWYx67Jerz4foeV7KiE0,1859
         | 
| 166 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
         | 
| 167 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         | 
| 168 | 
            +
            ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/RECORD,,
         | 
| 
            File without changes
         | 
| 
            File without changes
         |