ai-edge-torch-nightly 0.3.0.dev20250121__py3-none-any.whl → 0.3.0.dev20250123__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -1
 - ai_edge_torch/_convert/conversion.py +6 -10
 - ai_edge_torch/_convert/fx_passes/__init__.py +1 -1
 - ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py +6 -3
 - ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +9 -11
 - ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +3 -3
 - ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -0
 - ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py +33 -0
 - ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -0
 - ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -3
 - ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py +4 -4
 - ai_edge_torch/_convert/test/test_convert.py +2 -2
 - ai_edge_torch/fx_infra/__init__.py +32 -0
 - ai_edge_torch/fx_infra/_canonicalize_pass.py +27 -0
 - ai_edge_torch/fx_infra/_safe_run_decompositions.py +57 -0
 - ai_edge_torch/fx_infra/decomp.py +80 -0
 - ai_edge_torch/fx_infra/graph_utils.py +42 -0
 - ai_edge_torch/{fx_pass_base.py → fx_infra/pass_base.py} +0 -28
 - ai_edge_torch/generative/fx_passes/__init__.py +3 -3
 - ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +3 -3
 - ai_edge_torch/hlfb/mark_pattern/__init__.py +0 -1
 - ai_edge_torch/hlfb/mark_pattern/fx_utils.py +5 -21
 - ai_edge_torch/hlfb/mark_pattern/pattern.py +20 -11
 - ai_edge_torch/hlfb/test/test_mark_pattern.py +18 -15
 - ai_edge_torch/odml_torch/_torch_future.py +0 -27
 - ai_edge_torch/odml_torch/export.py +6 -8
 - ai_edge_torch/odml_torch/lowerings/__init__.py +1 -1
 - ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +65 -0
 - ai_edge_torch/version.py +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/METADATA +1 -1
 - {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/RECORD +34 -28
 - ai_edge_torch/odml_torch/lowerings/decomp.py +0 -69
 - {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/LICENSE +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/WHEEL +0 -0
 - {ai_edge_torch_nightly-0.3.0.dev20250121.dist-info → ai_edge_torch_nightly-0.3.0.dev20250123.dist-info}/top_level.txt +0 -0
 
    
        ai_edge_torch/__init__.py
    CHANGED
    
    | 
         @@ -12,7 +12,6 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
             
     | 
| 
       16 
15 
     | 
    
         
             
            from ai_edge_torch._config import config
         
     | 
| 
       17 
16 
     | 
    
         
             
            from ai_edge_torch._convert.converter import convert
         
     | 
| 
       18 
17 
     | 
    
         
             
            from ai_edge_torch._convert.converter import signature
         
     | 
| 
         @@ -20,6 +19,7 @@ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io 
     | 
|
| 
       20 
19 
     | 
    
         
             
            from ai_edge_torch.model import Model
         
     | 
| 
       21 
20 
     | 
    
         
             
            from ai_edge_torch.version import __version__
         
     | 
| 
       22 
21 
     | 
    
         | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
       23 
23 
     | 
    
         
             
            def load(path: str) -> Model:
         
     | 
| 
       24 
24 
     | 
    
         
             
              """Imports an ai_edge_torch model from disk.
         
     | 
| 
       25 
25 
     | 
    
         | 
| 
         @@ -17,7 +17,7 @@ import logging 
     | 
|
| 
       17 
17 
     | 
    
         
             
            from typing import Any, Literal, Optional, Union
         
     | 
| 
       18 
18 
     | 
    
         | 
| 
       19 
19 
     | 
    
         
             
            import ai_edge_torch
         
     | 
| 
       20 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 20 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       21 
21 
     | 
    
         
             
            from ai_edge_torch import lowertools
         
     | 
| 
       22 
22 
     | 
    
         
             
            from ai_edge_torch import model
         
     | 
| 
       23 
23 
     | 
    
         
             
            from ai_edge_torch._convert import fx_passes
         
     | 
| 
         @@ -53,7 +53,7 @@ def _run_convert_passes( 
     | 
|
| 
       53 
53 
     | 
    
         
             
                    fx_passes.CanonicalizePass(),
         
     | 
| 
       54 
54 
     | 
    
         
             
                ]
         
     | 
| 
       55 
55 
     | 
    
         | 
| 
       56 
     | 
    
         
            -
              exported_program =  
     | 
| 
      
 56 
     | 
    
         
            +
              exported_program = fx_infra.run_passes(exported_program, passes)
         
     | 
| 
       57 
57 
     | 
    
         
             
              return exported_program
         
     | 
| 
       58 
58 
     | 
    
         | 
| 
       59 
59 
     | 
    
         | 
| 
         @@ -125,14 +125,10 @@ def convert_signatures( 
     | 
|
| 
       125 
125 
     | 
    
         
             
                else:
         
     | 
| 
       126 
126 
     | 
    
         
             
                  exported_program = torch.export.export(**kwargs, strict=True)
         
     | 
| 
       127 
127 
     | 
    
         | 
| 
       128 
     | 
    
         
            -
                 
     | 
| 
       129 
     | 
    
         
            -
             
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
       131 
     | 
    
         
            -
             
     | 
| 
       132 
     | 
    
         
            -
                  # implementation of torch.export changes.
         
     | 
| 
       133 
     | 
    
         
            -
                  exported_program = exported_program.run_decompositions(
         
     | 
| 
       134 
     | 
    
         
            -
                      torch._decomp._decomp_table_to_post_autograd_aten()
         
     | 
| 
       135 
     | 
    
         
            -
                  )
         
     | 
| 
      
 128 
     | 
    
         
            +
                exported_program = fx_infra.safe_run_decompositions(
         
     | 
| 
      
 129 
     | 
    
         
            +
                    exported_program,
         
     | 
| 
      
 130 
     | 
    
         
            +
                    fx_infra.decomp.pre_convert_decomp(),
         
     | 
| 
      
 131 
     | 
    
         
            +
                )
         
     | 
| 
       136 
132 
     | 
    
         
             
                return exported_program
         
     | 
| 
       137 
133 
     | 
    
         | 
| 
       138 
134 
     | 
    
         
             
              exported_programs: torch.export.ExportedProgram = [
         
     | 
| 
         @@ -20,4 +20,4 @@ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import Bu 
     | 
|
| 
       20 
20 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass
         
     | 
| 
       21 
21 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass
         
     | 
| 
       22 
22 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.remove_non_user_outputs_pass import RemoveNonUserOutputsPass
         
     | 
| 
       23 
     | 
    
         
            -
            from ai_edge_torch. 
     | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import CanonicalizePass
         
     | 
| 
         @@ -14,7 +14,7 @@ 
     | 
|
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         | 
| 
       16 
16 
     | 
    
         
             
            from typing import Any, Callable
         
     | 
| 
       17 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 17 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       18 
18 
     | 
    
         
             
            from ai_edge_torch import lowertools
         
     | 
| 
       19 
19 
     | 
    
         
             
            import torch
         
     | 
| 
       20 
20 
     | 
    
         
             
            import torch.utils._pytree as pytree
         
     | 
| 
         @@ -25,6 +25,9 @@ _composite_builders: dict[ 
     | 
|
| 
       25 
25 
     | 
    
         | 
| 
       26 
26 
     | 
    
         | 
| 
       27 
27 
     | 
    
         
             
            def _register_composite_builder(op):
         
     | 
| 
      
 28 
     | 
    
         
            +
              # Remove op from pre_convert_decomp to keep this in the decomposed graph.
         
     | 
| 
      
 29 
     | 
    
         
            +
              fx_infra.decomp.remove_pre_convert_decomp(op)
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
       28 
31 
     | 
    
         
             
              def inner(func):
         
     | 
| 
       29 
32 
     | 
    
         
             
                if isinstance(op, torch._ops.OpOverloadPacket):
         
     | 
| 
       30 
33 
     | 
    
         
             
                  for overload in op.overloads():
         
     | 
| 
         @@ -276,7 +279,7 @@ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node): 
     | 
|
| 
       276 
279 
     | 
    
         
             
              node.target = embedding
         
     | 
| 
       277 
280 
     | 
    
         | 
| 
       278 
281 
     | 
    
         | 
| 
       279 
     | 
    
         
            -
            class BuildAtenCompositePass( 
     | 
| 
      
 282 
     | 
    
         
            +
            class BuildAtenCompositePass(fx_infra.PassBase):
         
     | 
| 
       280 
283 
     | 
    
         | 
| 
       281 
284 
     | 
    
         
             
              def call(self, graph_module: torch.fx.GraphModule):
         
     | 
| 
       282 
285 
     | 
    
         
             
                for node in graph_module.graph.nodes:
         
     | 
| 
         @@ -285,4 +288,4 @@ class BuildAtenCompositePass(fx_pass_base.PassBase): 
     | 
|
| 
       285 
288 
     | 
    
         | 
| 
       286 
289 
     | 
    
         
             
                graph_module.graph.lint()
         
     | 
| 
       287 
290 
     | 
    
         
             
                graph_module.recompile()
         
     | 
| 
       288 
     | 
    
         
            -
                return  
     | 
| 
      
 291 
     | 
    
         
            +
                return fx_infra.PassResult(graph_module, True)
         
     | 
| 
         @@ -16,7 +16,7 @@ 
     | 
|
| 
       16 
16 
     | 
    
         | 
| 
       17 
17 
     | 
    
         
             
            import functools
         
     | 
| 
       18 
18 
     | 
    
         | 
| 
       19 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 19 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       20 
20 
     | 
    
         
             
            from ai_edge_torch.hlfb import mark_pattern
         
     | 
| 
       21 
21 
     | 
    
         
             
            from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
         
     | 
| 
       22 
22 
     | 
    
         
             
            import torch
         
     | 
| 
         @@ -41,7 +41,7 @@ def _get_upsample_bilinear2d_pattern(): 
     | 
|
| 
       41 
41 
     | 
    
         
             
                      x, scale_factor=2, mode="bilinear", align_corners=False
         
     | 
| 
       42 
42 
     | 
    
         
             
                  ),
         
     | 
| 
       43 
43 
     | 
    
         
             
                  export_args=(torch.rand(1, 3, 100, 100),),
         
     | 
| 
       44 
     | 
    
         
            -
                   
     | 
| 
      
 44 
     | 
    
         
            +
                  extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
         
     | 
| 
       45 
45 
     | 
    
         
             
              )
         
     | 
| 
       46 
46 
     | 
    
         | 
| 
       47 
47 
     | 
    
         
             
              @pattern.register_attr_builder
         
     | 
| 
         @@ -65,7 +65,7 @@ def _get_upsample_bilinear2d_align_corners_pattern(): 
     | 
|
| 
       65 
65 
     | 
    
         
             
                      x, scale_factor=2, mode="bilinear", align_corners=True
         
     | 
| 
       66 
66 
     | 
    
         
             
                  ),
         
     | 
| 
       67 
67 
     | 
    
         
             
                  export_args=(torch.rand(1, 3, 100, 100),),
         
     | 
| 
       68 
     | 
    
         
            -
                   
     | 
| 
      
 68 
     | 
    
         
            +
                  extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
         
     | 
| 
       69 
69 
     | 
    
         
             
              )
         
     | 
| 
       70 
70 
     | 
    
         | 
| 
       71 
71 
     | 
    
         
             
              @pattern.register_attr_builder
         
     | 
| 
         @@ -89,7 +89,7 @@ def _get_interpolate_nearest2d_pattern(): 
     | 
|
| 
       89 
89 
     | 
    
         
             
                      x, scale_factor=2, mode="nearest"
         
     | 
| 
       90 
90 
     | 
    
         
             
                  ),
         
     | 
| 
       91 
91 
     | 
    
         
             
                  export_args=(torch.rand(1, 3, 100, 100),),
         
     | 
| 
       92 
     | 
    
         
            -
                   
     | 
| 
      
 92 
     | 
    
         
            +
                  extra_decomp_table=_INTERPOLATE_DECOMPOSITIONS,
         
     | 
| 
       93 
93 
     | 
    
         
             
              )
         
     | 
| 
       94 
94 
     | 
    
         | 
| 
       95 
95 
     | 
    
         
             
              @pattern.register_attr_builder
         
     | 
| 
         @@ -104,7 +104,7 @@ def _get_interpolate_nearest2d_pattern(): 
     | 
|
| 
       104 
104 
     | 
    
         
             
              return pattern
         
     | 
| 
       105 
105 
     | 
    
         | 
| 
       106 
106 
     | 
    
         | 
| 
       107 
     | 
    
         
            -
            class BuildInterpolateCompositePass( 
     | 
| 
      
 107 
     | 
    
         
            +
            class BuildInterpolateCompositePass(fx_infra.ExportedProgramPassBase):
         
     | 
| 
       108 
108 
     | 
    
         | 
| 
       109 
109 
     | 
    
         
             
              def __init__(self):
         
     | 
| 
       110 
110 
     | 
    
         
             
                super().__init__()
         
     | 
| 
         @@ -115,11 +115,9 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase): 
     | 
|
| 
       115 
115 
     | 
    
         
             
                ]
         
     | 
| 
       116 
116 
     | 
    
         | 
| 
       117 
117 
     | 
    
         
             
              def call(self, exported_program: torch.export.ExportedProgram):
         
     | 
| 
       118 
     | 
    
         
            -
                exported_program =  
     | 
| 
       119 
     | 
    
         
            -
                    exported_program, 
     | 
| 
       120 
     | 
    
         
            -
             
     | 
| 
       121 
     | 
    
         
            -
                exported_program = exported_program.run_decompositions(
         
     | 
| 
       122 
     | 
    
         
            -
                    _INTERPOLATE_DECOMPOSITIONS
         
     | 
| 
      
 118 
     | 
    
         
            +
                exported_program = fx_infra.safe_run_decompositions(
         
     | 
| 
      
 119 
     | 
    
         
            +
                    exported_program,
         
     | 
| 
      
 120 
     | 
    
         
            +
                    _INTERPOLATE_DECOMPOSITIONS,
         
     | 
| 
       123 
121 
     | 
    
         
             
                )
         
     | 
| 
       124 
122 
     | 
    
         | 
| 
       125 
123 
     | 
    
         
             
                graph_module = exported_program.graph_module
         
     | 
| 
         @@ -128,4 +126,4 @@ class BuildInterpolateCompositePass(fx_pass_base.ExportedProgramPassBase): 
     | 
|
| 
       128 
126 
     | 
    
         | 
| 
       129 
127 
     | 
    
         
             
                graph_module.graph.lint()
         
     | 
| 
       130 
128 
     | 
    
         
             
                graph_module.recompile()
         
     | 
| 
       131 
     | 
    
         
            -
                return  
     | 
| 
      
 129 
     | 
    
         
            +
                return fx_infra.ExportedProgramPassResult(exported_program, True)
         
     | 
| 
         @@ -13,7 +13,7 @@ 
     | 
|
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         | 
| 
       16 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 16 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       17 
17 
     | 
    
         
             
            from ai_edge_torch import lowertools
         
     | 
| 
       18 
18 
     | 
    
         
             
            import torch
         
     | 
| 
       19 
19 
     | 
    
         
             
            import torch.utils._pytree as pytree
         
     | 
| 
         @@ -61,7 +61,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule): 
     | 
|
| 
       61 
61 
     | 
    
         
             
              node.target = debuginfo_writer
         
     | 
| 
       62 
62 
     | 
    
         | 
| 
       63 
63 
     | 
    
         | 
| 
       64 
     | 
    
         
            -
            class InjectMlirDebuginfoPass( 
     | 
| 
      
 64 
     | 
    
         
            +
            class InjectMlirDebuginfoPass(fx_infra.PassBase):
         
     | 
| 
       65 
65 
     | 
    
         
             
              """DEPRECATED: Debuginfo is injected automatically by odml_torch."""
         
     | 
| 
       66 
66 
     | 
    
         | 
| 
       67 
67 
     | 
    
         
             
              def call(self, graph_module: torch.fx.GraphModule):
         
     | 
| 
         @@ -70,4 +70,4 @@ class InjectMlirDebuginfoPass(fx_pass_base.PassBase): 
     | 
|
| 
       70 
70 
     | 
    
         | 
| 
       71 
71 
     | 
    
         
             
                graph_module.graph.lint()
         
     | 
| 
       72 
72 
     | 
    
         
             
                graph_module.recompile()
         
     | 
| 
       73 
     | 
    
         
            -
                return  
     | 
| 
      
 73 
     | 
    
         
            +
                return fx_infra.PassResult(graph_module, True)
         
     | 
| 
         @@ -13,4 +13,5 @@ 
     | 
|
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         | 
| 
      
 16 
     | 
    
         
            +
            from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import _decomp_registry
         
     | 
| 
       16 
17 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass.pass_body import OptimizeLayoutTransposesPass  # NOQA
         
     | 
| 
         @@ -0,0 +1,33 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2025 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            """Remove decompositions for ops to keep in layout optimization."""
         
     | 
| 
      
 16 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
      
 17 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            __all__ = []
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            aten = torch.ops.aten
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
            _OPS_TO_KEEP = [
         
     | 
| 
      
 24 
     | 
    
         
            +
                aten.conv2d,
         
     | 
| 
      
 25 
     | 
    
         
            +
                aten.max_pool2d,
         
     | 
| 
      
 26 
     | 
    
         
            +
                aten._softmax.default,
         
     | 
| 
      
 27 
     | 
    
         
            +
                aten.group_norm.default,
         
     | 
| 
      
 28 
     | 
    
         
            +
                aten.native_group_norm.default,
         
     | 
| 
      
 29 
     | 
    
         
            +
                aten.reflection_pad2d.default,
         
     | 
| 
      
 30 
     | 
    
         
            +
            ]
         
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
            for op in _OPS_TO_KEEP:
         
     | 
| 
      
 33 
     | 
    
         
            +
              fx_infra.decomp.remove_pre_convert_decomp(op)
         
     | 
| 
         @@ -18,7 +18,7 @@ import operator 
     | 
|
| 
       18 
18 
     | 
    
         
             
            import os
         
     | 
| 
       19 
19 
     | 
    
         
             
            from typing import Union
         
     | 
| 
       20 
20 
     | 
    
         | 
| 
       21 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 21 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       22 
22 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check  # NOQA
         
     | 
| 
       23 
23 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark  # NOQA
         
     | 
| 
       24 
24 
     | 
    
         
             
            from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners  # NOQA
         
     | 
| 
         @@ -30,7 +30,7 @@ import torch.ao.quantization.quantize_pt2e 
     | 
|
| 
       30 
30 
     | 
    
         
             
            TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
         
     | 
| 
       31 
31 
     | 
    
         | 
| 
       32 
32 
     | 
    
         | 
| 
       33 
     | 
    
         
            -
            class OptimizeLayoutTransposesPass( 
     | 
| 
      
 33 
     | 
    
         
            +
            class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
         
     | 
| 
       34 
34 
     | 
    
         | 
| 
       35 
35 
     | 
    
         
             
              def get_source_meta(self, node: torch.fx.Node):
         
     | 
| 
       36 
36 
     | 
    
         
             
                keys = ["stack_trace", "nn_module_stack", "source_fn_stack", "from_node"]
         
     | 
| 
         @@ -300,4 +300,4 @@ class OptimizeLayoutTransposesPass(fx_pass_base.ExportedProgramPassBase): 
     | 
|
| 
       300 
300 
     | 
    
         
             
                # Mark const node again for debugging
         
     | 
| 
       301 
301 
     | 
    
         
             
                self.mark_const_nodes(exported_program)
         
     | 
| 
       302 
302 
     | 
    
         | 
| 
       303 
     | 
    
         
            -
                return  
     | 
| 
      
 303 
     | 
    
         
            +
                return fx_infra.ExportedProgramPassResult(exported_program, True)
         
     | 
| 
         @@ -15,11 +15,11 @@ 
     | 
|
| 
       15 
15 
     | 
    
         
             
            """Pass to remove all non user outputs from exported program."""
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
       17 
17 
     | 
    
         | 
| 
       18 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 18 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       19 
19 
     | 
    
         
             
            import torch
         
     | 
| 
       20 
20 
     | 
    
         | 
| 
       21 
21 
     | 
    
         | 
| 
       22 
     | 
    
         
            -
            class RemoveNonUserOutputsPass( 
     | 
| 
      
 22 
     | 
    
         
            +
            class RemoveNonUserOutputsPass(fx_infra.ExportedProgramPassBase):
         
     | 
| 
       23 
23 
     | 
    
         
             
              """This pass removes all non user outputs from the exported program's output.
         
     | 
| 
       24 
24 
     | 
    
         | 
| 
       25 
25 
     | 
    
         
             
              The FX graph may output more tensors/data than what user's original model
         
     | 
| 
         @@ -47,6 +47,6 @@ class RemoveNonUserOutputsPass(fx_pass_base.ExportedProgramPassBase): 
     | 
|
| 
       47 
47 
     | 
    
         
             
                  node.args = (tuple(new_outputs),)
         
     | 
| 
       48 
48 
     | 
    
         
             
                  exported_program.graph_signature.output_specs = new_output_specs
         
     | 
| 
       49 
49 
     | 
    
         | 
| 
       50 
     | 
    
         
            -
                exported_program. 
     | 
| 
      
 50 
     | 
    
         
            +
                exported_program.graph.eliminate_dead_code()
         
     | 
| 
       51 
51 
     | 
    
         
             
                exported_program.graph_module.recompile()
         
     | 
| 
       52 
     | 
    
         
            -
                return  
     | 
| 
      
 52 
     | 
    
         
            +
                return fx_infra.ExportedProgramPassResult(exported_program, True)
         
     | 
| 
         @@ -511,7 +511,7 @@ class TestConvert(googletest.TestCase): 
     | 
|
| 
       511 
511 
     | 
    
         
             
                # Step 1: export resnet18
         
     | 
| 
       512 
512 
     | 
    
         
             
                args = (torch.randn(1, 3, 224, 224),)
         
     | 
| 
       513 
513 
     | 
    
         
             
                m = torchvision.models.resnet18().eval()
         
     | 
| 
       514 
     | 
    
         
            -
                m = torch. 
     | 
| 
      
 514 
     | 
    
         
            +
                m = torch.export.export_for_training(m, args).module()
         
     | 
| 
       515 
515 
     | 
    
         | 
| 
       516 
516 
     | 
    
         
             
                # Step 2: Insert observers or fake quantize modules
         
     | 
| 
       517 
517 
     | 
    
         
             
                quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
         
     | 
| 
         @@ -533,7 +533,7 @@ class TestConvert(googletest.TestCase): 
     | 
|
| 
       533 
533 
     | 
    
         
             
                # Step 1: export resnet18
         
     | 
| 
       534 
534 
     | 
    
         
             
                args = (torch.randn(1, 3, 224, 224),)
         
     | 
| 
       535 
535 
     | 
    
         
             
                m = torchvision.models.resnet18().eval()
         
     | 
| 
       536 
     | 
    
         
            -
                m = torch. 
     | 
| 
      
 536 
     | 
    
         
            +
                m = torch.export.export_for_training(m, args).module()
         
     | 
| 
       537 
537 
     | 
    
         | 
| 
       538 
538 
     | 
    
         
             
                # Step 2: Insert observers or fake quantize modules
         
     | 
| 
       539 
539 
     | 
    
         
             
                quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
         
     | 
| 
         @@ -0,0 +1,32 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2025 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import _canonicalize_pass
         
     | 
| 
      
 17 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import _safe_run_decompositions
         
     | 
| 
      
 18 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import decomp
         
     | 
| 
      
 19 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import graph_utils
         
     | 
| 
      
 20 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import pass_base
         
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
            PassBase = pass_base.PassBase
         
     | 
| 
      
 24 
     | 
    
         
            +
            PassResult = pass_base.PassResult
         
     | 
| 
      
 25 
     | 
    
         
            +
            FxPassBase = pass_base.FxPassBase
         
     | 
| 
      
 26 
     | 
    
         
            +
            FxPassResult = pass_base.FxPassResult
         
     | 
| 
      
 27 
     | 
    
         
            +
            ExportedProgramPassBase = pass_base.ExportedProgramPassBase
         
     | 
| 
      
 28 
     | 
    
         
            +
            ExportedProgramPassResult = pass_base.ExportedProgramPassResult
         
     | 
| 
      
 29 
     | 
    
         
            +
            run_passes = pass_base.run_passes
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
            CanonicalizePass = _canonicalize_pass.CanonicalizePass
         
     | 
| 
      
 32 
     | 
    
         
            +
            safe_run_decompositions = _safe_run_decompositions.safe_run_decompositions
         
     | 
| 
         @@ -0,0 +1,27 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2025 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import _safe_run_decompositions
         
     | 
| 
      
 16 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import pass_base
         
     | 
| 
      
 17 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
            class CanonicalizePass(pass_base.ExportedProgramPassBase):
         
     | 
| 
      
 21 
     | 
    
         
            +
             
     | 
| 
      
 22 
     | 
    
         
            +
              def call(self, exported_program: torch.export.ExportedProgram):
         
     | 
| 
      
 23 
     | 
    
         
            +
                exported_program = _safe_run_decompositions.safe_run_decompositions(
         
     | 
| 
      
 24 
     | 
    
         
            +
                    exported_program, {}
         
     | 
| 
      
 25 
     | 
    
         
            +
                )
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
                return pass_base.ExportedProgramPassResult(exported_program, True)
         
     | 
| 
         @@ -0,0 +1,57 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            """ExportedProgram.run_decompositions wrapper to handle unexpected export behavior."""
         
     | 
| 
      
 16 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            # A dummy decomp table for running ExportedProgram.run_decompositions without
         
     | 
| 
      
 20 
     | 
    
         
            +
            # any op decompositions but just aot_export_module. Due to the check in
         
     | 
| 
      
 21 
     | 
    
         
            +
            # run_decompositions, if None or an empty dict is passed as decomp_table,
         
     | 
| 
      
 22 
     | 
    
         
            +
            # it runs the default ATen-to-Core-ATen decompositions. Therefore a non-empty
         
     | 
| 
      
 23 
     | 
    
         
            +
            # dummy decomp table is needed.
         
     | 
| 
      
 24 
     | 
    
         
            +
            # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
         
     | 
| 
      
 25 
     | 
    
         
            +
            _DUMMY_DECOMP_TABLE = {
         
     | 
| 
      
 26 
     | 
    
         
            +
                torch._ops.OperatorBase(): lambda: None,
         
     | 
| 
      
 27 
     | 
    
         
            +
            }
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
             
     | 
| 
      
 30 
     | 
    
         
            +
            def safe_run_decompositions(exported_program, decomp_table=None):
         
     | 
| 
      
 31 
     | 
    
         
            +
              """Wrapper for ExportedProgram.run_decompositions to handle unexpected export behavior."""
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
      
 33 
     | 
    
         
            +
              if decomp_table is not None and not decomp_table:
         
     | 
| 
      
 34 
     | 
    
         
            +
                # Empty decomp table means no op decompositions. Use dummy decomp table
         
     | 
| 
      
 35 
     | 
    
         
            +
                # instead for backward compatibility.
         
     | 
| 
      
 36 
     | 
    
         
            +
                decomp_table = _DUMMY_DECOMP_TABLE
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
      
 38 
     | 
    
         
            +
              for node in exported_program.graph.nodes:
         
     | 
| 
      
 39 
     | 
    
         
            +
                if node.target == torch.ops.aten.view.default:
         
     | 
| 
      
 40 
     | 
    
         
            +
                  # Passes or torch.export may generate aten.view nodes not respecting the
         
     | 
| 
      
 41 
     | 
    
         
            +
                  # tensor memory format. Change every aten.view to torch.reshape
         
     | 
| 
      
 42 
     | 
    
         
            +
                  # for retracing. If the input memory format is already contiguous,
         
     | 
| 
      
 43 
     | 
    
         
            +
                  # retracing in run_decompositions below would decompose torch.reshape
         
     | 
| 
      
 44 
     | 
    
         
            +
                  # back to one aten.view.
         
     | 
| 
      
 45 
     | 
    
         
            +
                  node.target = lambda self, size: torch.reshape(self.contiguous(), size)
         
     | 
| 
      
 46 
     | 
    
         
            +
             
     | 
| 
      
 47 
     | 
    
         
            +
              exported_program = exported_program.run_decompositions(decomp_table)
         
     | 
| 
      
 48 
     | 
    
         
            +
             
     | 
| 
      
 49 
     | 
    
         
            +
              if hasattr(torch.ops.aten, "_assert_tensor_metadata"):
         
     | 
| 
      
 50 
     | 
    
         
            +
                for node in exported_program.graph.nodes:
         
     | 
| 
      
 51 
     | 
    
         
            +
                  if node.target == torch.ops.aten._assert_tensor_metadata.default:
         
     | 
| 
      
 52 
     | 
    
         
            +
                    exported_program.graph.erase_node(node)
         
     | 
| 
      
 53 
     | 
    
         
            +
             
     | 
| 
      
 54 
     | 
    
         
            +
              exported_program.graph.eliminate_dead_code()
         
     | 
| 
      
 55 
     | 
    
         
            +
              exported_program.graph_module.recompile()
         
     | 
| 
      
 56 
     | 
    
         
            +
             
     | 
| 
      
 57 
     | 
    
         
            +
              return exported_program
         
     | 
| 
         @@ -0,0 +1,80 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2025 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            """Decomposition tables applied before conversion and before lowering."""
         
     | 
| 
      
 16 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            # Decompositions to be run after torch.export before conversion and any
         
     | 
| 
      
 19 
     | 
    
         
            +
            # passes. Remove ops from this decomposition table if they need to be preserved
         
     | 
| 
      
 20 
     | 
    
         
            +
            # in passes.
         
     | 
| 
      
 21 
     | 
    
         
            +
            _pre_convert_decomp = torch._decomp.core_aten_decompositions().copy()
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
            # Decompositions to be run after conversion before odml_torch passes and
         
     | 
| 
      
 25 
     | 
    
         
            +
            # lowerings.
         
     | 
| 
      
 26 
     | 
    
         
            +
            _pre_lower_decomp = torch._decomp.core_aten_decompositions().copy()
         
     | 
| 
      
 27 
     | 
    
         
            +
             
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
            def _get_ops(op):
         
     | 
| 
      
 30 
     | 
    
         
            +
              if isinstance(op, torch._ops.OpOverloadPacket):
         
     | 
| 
      
 31 
     | 
    
         
            +
                return [getattr(op, overload) for overload in op.overloads()]
         
     | 
| 
      
 32 
     | 
    
         
            +
              else:
         
     | 
| 
      
 33 
     | 
    
         
            +
                return [op]
         
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
            def pre_convert_decomp():
         
     | 
| 
      
 37 
     | 
    
         
            +
              """Decompositions to be run after torch.export before conversion and any passes."""
         
     | 
| 
      
 38 
     | 
    
         
            +
              return _pre_convert_decomp.copy()
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
            def pre_lower_decomp():
         
     | 
| 
      
 42 
     | 
    
         
            +
              """Decompositions to be run after conversion before odml_torch passes and lowerings."""
         
     | 
| 
      
 43 
     | 
    
         
            +
              return _pre_lower_decomp.copy()
         
     | 
| 
      
 44 
     | 
    
         
            +
             
     | 
| 
      
 45 
     | 
    
         
            +
             
     | 
| 
      
 46 
     | 
    
         
            +
            def remove_pre_lower_decomp(op):
         
     | 
| 
      
 47 
     | 
    
         
            +
              # Also remove from pre_convert_decomp, which always runs before pre_lower_
         
     | 
| 
      
 48 
     | 
    
         
            +
              # decomp.
         
     | 
| 
      
 49 
     | 
    
         
            +
              remove_pre_convert_decomp(op)
         
     | 
| 
      
 50 
     | 
    
         
            +
             
     | 
| 
      
 51 
     | 
    
         
            +
              for op_ in _get_ops(op):
         
     | 
| 
      
 52 
     | 
    
         
            +
                _pre_lower_decomp.pop(op_, None)
         
     | 
| 
      
 53 
     | 
    
         
            +
             
     | 
| 
      
 54 
     | 
    
         
            +
             
     | 
| 
      
 55 
     | 
    
         
            +
            def remove_pre_convert_decomp(op):
         
     | 
| 
      
 56 
     | 
    
         
            +
              for op_ in _get_ops(op):
         
     | 
| 
      
 57 
     | 
    
         
            +
                _pre_convert_decomp.pop(op_, None)
         
     | 
| 
      
 58 
     | 
    
         
            +
             
     | 
| 
      
 59 
     | 
    
         
            +
             
     | 
| 
      
 60 
     | 
    
         
            +
            def add_pre_convert_decomp(op, decomp):
         
     | 
| 
      
 61 
     | 
    
         
            +
              # Also add decomp to pre_lower_decomp which runs after pre_convert_decomp.
         
     | 
| 
      
 62 
     | 
    
         
            +
              add_pre_lower_decomp(op, decomp)
         
     | 
| 
      
 63 
     | 
    
         
            +
             
     | 
| 
      
 64 
     | 
    
         
            +
              for op_ in _get_ops(op):
         
     | 
| 
      
 65 
     | 
    
         
            +
                _pre_convert_decomp[op_] = decomp
         
     | 
| 
      
 66 
     | 
    
         
            +
             
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
      
 68 
     | 
    
         
            +
            def add_pre_lower_decomp(op, decomp):
         
     | 
| 
      
 69 
     | 
    
         
            +
              for op_ in _get_ops(op):
         
     | 
| 
      
 70 
     | 
    
         
            +
                _pre_lower_decomp[op_] = decomp
         
     | 
| 
      
 71 
     | 
    
         
            +
             
     | 
| 
      
 72 
     | 
    
         
            +
             
     | 
| 
      
 73 
     | 
    
         
            +
            def update_pre_convert_decomp(decomps):
         
     | 
| 
      
 74 
     | 
    
         
            +
              for op, decomp in decomps.items():
         
     | 
| 
      
 75 
     | 
    
         
            +
                add_pre_convert_decomp(op, decomp)
         
     | 
| 
      
 76 
     | 
    
         
            +
             
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
            def update_pre_lower_decomp(decomps):
         
     | 
| 
      
 79 
     | 
    
         
            +
              for op, decomp in decomps.items():
         
     | 
| 
      
 80 
     | 
    
         
            +
                add_pre_lower_decomp(op, decomp)
         
     | 
| 
         @@ -0,0 +1,42 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            """FX graph utilities."""
         
     | 
| 
      
 16 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            def remove_dangling_args(graph_module: torch.fx.GraphModule):
         
     | 
| 
      
 20 
     | 
    
         
            +
              """Removes dangling args from the graph."""
         
     | 
| 
      
 21 
     | 
    
         
            +
              for node in graph_module.graph.nodes:
         
     | 
| 
      
 22 
     | 
    
         
            +
                if node.op == "placeholder" and not node.users:
         
     | 
| 
      
 23 
     | 
    
         
            +
                  graph_module.graph.erase_node(node)
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
              graph_module.graph.lint()
         
     | 
| 
      
 26 
     | 
    
         
            +
              graph_module.recompile()
         
     | 
| 
      
 27 
     | 
    
         
            +
              return graph_module
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
             
     | 
| 
      
 30 
     | 
    
         
            +
            def remove_assert_tensor_metadata_nodes(graph_module: torch.fx.GraphModule):
         
     | 
| 
      
 31 
     | 
    
         
            +
              """Removes aten._assert_tensor_metadata nodes from the graph.
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
      
 33 
     | 
    
         
            +
              This op is inserted by torch.export to check tensor metadata on custom ops. It
         
     | 
| 
      
 34 
     | 
    
         
            +
              can break pattern matching and lowering.
         
     | 
| 
      
 35 
     | 
    
         
            +
              """
         
     | 
| 
      
 36 
     | 
    
         
            +
              for node in graph_module.graph.nodes:
         
     | 
| 
      
 37 
     | 
    
         
            +
                if node.target == torch.ops.aten._assert_tensor_metadata.default:
         
     | 
| 
      
 38 
     | 
    
         
            +
                  graph_module.graph.erase_node(node)
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
              graph_module.graph.lint()
         
     | 
| 
      
 41 
     | 
    
         
            +
              graph_module.recompile()
         
     | 
| 
      
 42 
     | 
    
         
            +
              return graph_module
         
     | 
| 
         @@ -80,31 +80,3 @@ def run_passes( 
     | 
|
| 
       80 
80 
     | 
    
         
             
                        constants=exported_program.constants,
         
     | 
| 
       81 
81 
     | 
    
         
             
                    )
         
     | 
| 
       82 
82 
     | 
    
         
             
              return exported_program
         
     | 
| 
       83 
     | 
    
         
            -
             
     | 
| 
       84 
     | 
    
         
            -
             
     | 
| 
       85 
     | 
    
         
            -
            class CanonicalizePass(ExportedProgramPassBase):
         
     | 
| 
       86 
     | 
    
         
            -
             
     | 
| 
       87 
     | 
    
         
            -
              # A dummy decomp table for running ExportedProgram.run_decompositions without
         
     | 
| 
       88 
     | 
    
         
            -
              # any op decompositions but just aot_export_module. Due to the check in
         
     | 
| 
       89 
     | 
    
         
            -
              # run_decompositions, if None or an empty dict is passed as decomp_table,
         
     | 
| 
       90 
     | 
    
         
            -
              # it will run the default aten-coreaten decompositions. Therefore a non-empty
         
     | 
| 
       91 
     | 
    
         
            -
              # dummy decomp table is needed.
         
     | 
| 
       92 
     | 
    
         
            -
              # Ref: https://github.com/pytorch/pytorch/blob/db895ace1d36726e64781774f53b3d3098206116/torch/export/exported_program.py#L543
         
     | 
| 
       93 
     | 
    
         
            -
              _DUMMY_DECOMP_TABLE = {
         
     | 
| 
       94 
     | 
    
         
            -
                  torch._ops.OperatorBase(): lambda: None,
         
     | 
| 
       95 
     | 
    
         
            -
              }
         
     | 
| 
       96 
     | 
    
         
            -
             
     | 
| 
       97 
     | 
    
         
            -
              def call(self, exported_program: torch.export.ExportedProgram):
         
     | 
| 
       98 
     | 
    
         
            -
                for node in exported_program.graph.nodes:
         
     | 
| 
       99 
     | 
    
         
            -
                  if node.target == torch.ops.aten.view.default:
         
     | 
| 
       100 
     | 
    
         
            -
                    # Passes or torch.export may generate aten.view nodes not respecting the
         
     | 
| 
       101 
     | 
    
         
            -
                    # tensor memory format. Changes all the aten.view to torch.reshape
         
     | 
| 
       102 
     | 
    
         
            -
                    # for retracing. If the input memory format is already contiguous,
         
     | 
| 
       103 
     | 
    
         
            -
                    # retracing in run_decomposition below would decompose torch.reshape
         
     | 
| 
       104 
     | 
    
         
            -
                    # back to one aten.view.
         
     | 
| 
       105 
     | 
    
         
            -
                    node.target = lambda self, size: torch.reshape(self.contiguous(), size)
         
     | 
| 
       106 
     | 
    
         
            -
             
     | 
| 
       107 
     | 
    
         
            -
                exported_program = exported_program.run_decompositions(
         
     | 
| 
       108 
     | 
    
         
            -
                    self._DUMMY_DECOMP_TABLE
         
     | 
| 
       109 
     | 
    
         
            -
                )
         
     | 
| 
       110 
     | 
    
         
            -
                return ExportedProgramPassResult(exported_program, True)
         
     | 
| 
         @@ -12,8 +12,8 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
       16 
     | 
    
         
            -
            from ai_edge_torch. 
     | 
| 
      
 15 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
      
 16 
     | 
    
         
            +
            from ai_edge_torch.fx_infra import CanonicalizePass
         
     | 
| 
       17 
17 
     | 
    
         
             
            from ai_edge_torch.generative.fx_passes.remove_sdpa_zero_mask_pass import RemoveSDPACompositeZeroMaskPass
         
     | 
| 
       18 
18 
     | 
    
         
             
            import torch
         
     | 
| 
       19 
19 
     | 
    
         | 
| 
         @@ -21,7 +21,7 @@ import torch 
     | 
|
| 
       21 
21 
     | 
    
         
             
            def run_generative_passes(
         
     | 
| 
       22 
22 
     | 
    
         
             
                exported_program: torch.export.ExportedProgram,
         
     | 
| 
       23 
23 
     | 
    
         
             
            ) -> torch.export.ExportedProgram:
         
     | 
| 
       24 
     | 
    
         
            -
              return  
     | 
| 
      
 24 
     | 
    
         
            +
              return fx_infra.run_passes(
         
     | 
| 
       25 
25 
     | 
    
         
             
                  exported_program,
         
     | 
| 
       26 
26 
     | 
    
         
             
                  [
         
     | 
| 
       27 
27 
     | 
    
         
             
                      RemoveSDPACompositeZeroMaskPass(),
         
     | 
| 
         @@ -12,12 +12,12 @@ 
     | 
|
| 
       12 
12 
     | 
    
         
             
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
13 
     | 
    
         
             
            # limitations under the License.
         
     | 
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 15 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       16 
16 
     | 
    
         
             
            from ai_edge_torch import lowertools
         
     | 
| 
       17 
17 
     | 
    
         
             
            import torch
         
     | 
| 
       18 
18 
     | 
    
         | 
| 
       19 
19 
     | 
    
         | 
| 
       20 
     | 
    
         
            -
            class RemoveSDPACompositeZeroMaskPass( 
     | 
| 
      
 20 
     | 
    
         
            +
            class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
         
     | 
| 
       21 
21 
     | 
    
         | 
| 
       22 
22 
     | 
    
         
             
              def is_zero_tensor_node(self, node: torch.fx.Node):
         
     | 
| 
       23 
23 
     | 
    
         
             
                return node.target == torch.ops.aten.zeros.default
         
     | 
| 
         @@ -47,4 +47,4 @@ class RemoveSDPACompositeZeroMaskPass(fx_pass_base.ExportedProgramPassBase): 
     | 
|
| 
       47 
47 
     | 
    
         | 
| 
       48 
48 
     | 
    
         
             
                exported_program.graph_module.graph.lint()
         
     | 
| 
       49 
49 
     | 
    
         
             
                exported_program.graph_module.recompile()
         
     | 
| 
       50 
     | 
    
         
            -
                return  
     | 
| 
      
 50 
     | 
    
         
            +
                return fx_infra.ExportedProgramPassResult(exported_program, True)
         
     | 
| 
         @@ -14,8 +14,13 @@ 
     | 
|
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         
             
            """FX graph utilities for pattern matching clean ups."""
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
      
 17 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       17 
18 
     | 
    
         
             
            import torch
         
     | 
| 
       18 
19 
     | 
    
         | 
| 
      
 20 
     | 
    
         
            +
            remove_dangling_args = fx_infra.graph_utils.remove_dangling_args
         
     | 
| 
      
 21 
     | 
    
         
            +
            remove_assert_tensor_metadata_nodes = (
         
     | 
| 
      
 22 
     | 
    
         
            +
                fx_infra.graph_utils.remove_assert_tensor_metadata_nodes
         
     | 
| 
      
 23 
     | 
    
         
            +
            )
         
     | 
| 
       19 
24 
     | 
    
         | 
| 
       20 
25 
     | 
    
         
             
            def is_clone_op(node: torch.fx.Node) -> bool:
         
     | 
| 
       21 
26 
     | 
    
         
             
              """Checks if the node is a clone op."""
         
     | 
| 
         @@ -46,24 +51,3 @@ def remove_clone_ops(gm: torch.fx.GraphModule): 
     | 
|
| 
       46 
51 
     | 
    
         
             
              gm.graph.lint()
         
     | 
| 
       47 
52 
     | 
    
         
             
              gm.recompile()
         
     | 
| 
       48 
53 
     | 
    
         
             
              return gm
         
     | 
| 
       49 
     | 
    
         
            -
             
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
            def remove_dangling_args(gm: torch.fx.GraphModule):
         
     | 
| 
       52 
     | 
    
         
            -
              """Removes dangling args from the graph.
         
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
     | 
    
         
            -
              Args:
         
     | 
| 
       55 
     | 
    
         
            -
                gm: The graph module to remove dangling args from.
         
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
       57 
     | 
    
         
            -
              Returns:
         
     | 
| 
       58 
     | 
    
         
            -
                The graph module with dangling args removed.
         
     | 
| 
       59 
     | 
    
         
            -
              """
         
     | 
| 
       60 
     | 
    
         
            -
              nodes_to_erase = []
         
     | 
| 
       61 
     | 
    
         
            -
              for node in gm.graph.nodes:
         
     | 
| 
       62 
     | 
    
         
            -
                if node.op == "placeholder" and len(node.users) == 0:
         
     | 
| 
       63 
     | 
    
         
            -
                  nodes_to_erase.append(node)
         
     | 
| 
       64 
     | 
    
         
            -
              for node in nodes_to_erase:
         
     | 
| 
       65 
     | 
    
         
            -
                gm.graph.erase_node(node)
         
     | 
| 
       66 
     | 
    
         
            -
             
     | 
| 
       67 
     | 
    
         
            -
              gm.graph.lint()
         
     | 
| 
       68 
     | 
    
         
            -
              gm.recompile()
         
     | 
| 
       69 
     | 
    
         
            -
              return gm
         
     | 
| 
         @@ -17,7 +17,7 @@ 
     | 
|
| 
       17 
17 
     | 
    
         
             
            import dataclasses
         
     | 
| 
       18 
18 
     | 
    
         
             
            from typing import Any, Callable, Optional, Union
         
     | 
| 
       19 
19 
     | 
    
         | 
| 
       20 
     | 
    
         
            -
            from ai_edge_torch import  
     | 
| 
      
 20 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       21 
21 
     | 
    
         
             
            from ai_edge_torch.hlfb.mark_pattern import fx_utils
         
     | 
| 
       22 
22 
     | 
    
         
             
            import torch
         
     | 
| 
       23 
23 
     | 
    
         | 
| 
         @@ -118,7 +118,7 @@ def _find_scalar_attr( 
     | 
|
| 
       118 
118 
     | 
    
         
             
                track_args[tracker.pattern_arg_pos] = source
         
     | 
| 
       119 
119 
     | 
    
         
             
                ep = torch.export.export(pattern_module, tuple(track_args))
         
     | 
| 
       120 
120 
     | 
    
         
             
                if decomp_table is not None:
         
     | 
| 
       121 
     | 
    
         
            -
                  ep =  
     | 
| 
      
 121 
     | 
    
         
            +
                  ep = fx_infra.run_passes(ep, [fx_infra.CanonicalizePass()])
         
     | 
| 
       122 
122 
     | 
    
         
             
                  ep = ep.run_decompositions(decomp_table)
         
     | 
| 
       123 
123 
     | 
    
         | 
| 
       124 
124 
     | 
    
         
             
                scalar_locs = set()
         
     | 
| 
         @@ -152,13 +152,15 @@ class Pattern: 
     | 
|
| 
       152 
152 
     | 
    
         
             
                  self,
         
     | 
| 
       153 
153 
     | 
    
         
             
                  name: str,
         
     | 
| 
       154 
154 
     | 
    
         
             
                  module: Union[Callable, torch.nn.Module],
         
     | 
| 
       155 
     | 
    
         
            -
                  export_args: tuple[Any],
         
     | 
| 
      
 155 
     | 
    
         
            +
                  export_args: tuple[Any, ...],
         
     | 
| 
       156 
156 
     | 
    
         
             
                  *,
         
     | 
| 
       157 
157 
     | 
    
         
             
                  attr_builder: Callable[
         
     | 
| 
       158 
158 
     | 
    
         
             
                      ["Pattern", GraphModule, InternalMatch], Optional[dict[str, Any]]
         
     | 
| 
       159 
159 
     | 
    
         
             
                  ] = None,
         
     | 
| 
       160 
160 
     | 
    
         
             
                  scalar_attr_trackers: list[ScalarAttrTracker] = None,
         
     | 
| 
       161 
     | 
    
         
            -
                   
     | 
| 
      
 161 
     | 
    
         
            +
                  extra_decomp_table: Optional[
         
     | 
| 
      
 162 
     | 
    
         
            +
                      dict[torch._ops.OperatorBase, Callable]
         
     | 
| 
      
 163 
     | 
    
         
            +
                  ] = None,
         
     | 
| 
       162 
164 
     | 
    
         
             
              ):
         
     | 
| 
       163 
165 
     | 
    
         
             
                """The PyTorch computation pattern to match against a model.
         
     | 
| 
       164 
166 
     | 
    
         | 
| 
         @@ -177,8 +179,9 @@ class Pattern: 
     | 
|
| 
       177 
179 
     | 
    
         
             
                  scalar_attr_trackers (list[ScalarAttrTracker]): the trackers for scalar
         
     | 
| 
       178 
180 
     | 
    
         
             
                    args in `export_args`, which are used to track the attr occurrence(s)
         
     | 
| 
       179 
181 
     | 
    
         
             
                    and retrieve their values from the matched subgraph.
         
     | 
| 
       180 
     | 
    
         
            -
                   
     | 
| 
       181 
     | 
    
         
            -
                    decomposition  
     | 
| 
      
 182 
     | 
    
         
            +
                  extra_decomp_table (Optional[dict[torch._ops.OperatorBase, Callable]]):
         
     | 
| 
      
 183 
     | 
    
         
            +
                    Extra decomposition to be run on the pattern's exported program (in
         
     | 
| 
      
 184 
     | 
    
         
            +
                    addition to the default pre_convert_decomp in fx_infra.decomp).
         
     | 
| 
       182 
185 
     | 
    
         
             
                """
         
     | 
| 
       183 
186 
     | 
    
         
             
                if not isinstance(module, torch.nn.Module):
         
     | 
| 
       184 
187 
     | 
    
         | 
| 
         @@ -200,11 +203,14 @@ class Pattern: 
     | 
|
| 
       200 
203 
     | 
    
         
             
                )
         
     | 
| 
       201 
204 
     | 
    
         | 
| 
       202 
205 
     | 
    
         
             
                exported_program = torch.export.export(module, export_args)
         
     | 
| 
       203 
     | 
    
         
            -
             
     | 
| 
       204 
     | 
    
         
            -
             
     | 
| 
       205 
     | 
    
         
            -
             
     | 
| 
       206 
     | 
    
         
            -
                  )
         
     | 
| 
       207 
     | 
    
         
            -
             
     | 
| 
      
 206 
     | 
    
         
            +
             
     | 
| 
      
 207 
     | 
    
         
            +
                decomp_table = fx_infra.decomp.pre_convert_decomp()
         
     | 
| 
      
 208 
     | 
    
         
            +
                if extra_decomp_table is not None:
         
     | 
| 
      
 209 
     | 
    
         
            +
                  decomp_table.update(extra_decomp_table)
         
     | 
| 
      
 210 
     | 
    
         
            +
             
     | 
| 
      
 211 
     | 
    
         
            +
                exported_program = fx_infra.safe_run_decompositions(
         
     | 
| 
      
 212 
     | 
    
         
            +
                    exported_program, decomp_table
         
     | 
| 
      
 213 
     | 
    
         
            +
                )
         
     | 
| 
       208 
214 
     | 
    
         | 
| 
       209 
215 
     | 
    
         
             
                self.exported_program = exported_program
         
     | 
| 
       210 
216 
     | 
    
         
             
                self.graph_module = self.exported_program.graph_module
         
     | 
| 
         @@ -222,6 +228,9 @@ class Pattern: 
     | 
|
| 
       222 
228 
     | 
    
         
             
                # sanitization.
         
     | 
| 
       223 
229 
     | 
    
         
             
                self.graph_module = fx_utils.remove_clone_ops(self.graph_module)
         
     | 
| 
       224 
230 
     | 
    
         
             
                self.graph_module = fx_utils.remove_dangling_args(self.graph_module)
         
     | 
| 
      
 231 
     | 
    
         
            +
                self.graph_module = fx_utils.remove_assert_tensor_metadata_nodes(
         
     | 
| 
      
 232 
     | 
    
         
            +
                    self.graph_module
         
     | 
| 
      
 233 
     | 
    
         
            +
                )
         
     | 
| 
       225 
234 
     | 
    
         | 
| 
       226 
235 
     | 
    
         
             
                # Builds list of ordered input and output nodes.
         
     | 
| 
       227 
236 
     | 
    
         
             
                self.graph_nodes_map = {}
         
     | 
| 
         @@ -14,6 +14,7 @@ 
     | 
|
| 
       14 
14 
     | 
    
         
             
            # ==============================================================================
         
     | 
| 
       15 
15 
     | 
    
         
             
            """Tests for mark_pattern."""
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
      
 17 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       17 
18 
     | 
    
         
             
            from ai_edge_torch import lowertools
         
     | 
| 
       18 
19 
     | 
    
         
             
            from ai_edge_torch.hlfb import mark_pattern
         
     | 
| 
       19 
20 
     | 
    
         
             
            from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
         
     | 
| 
         @@ -22,11 +23,13 @@ import torch 
     | 
|
| 
       22 
23 
     | 
    
         
             
            from absl.testing import absltest as googletest
         
     | 
| 
       23 
24 
     | 
    
         | 
| 
       24 
25 
     | 
    
         | 
| 
       25 
     | 
    
         
            -
            def  
     | 
| 
       26 
     | 
    
         
            -
               
     | 
| 
       27 
     | 
    
         
            -
             
     | 
| 
       28 
     | 
    
         
            -
               
     | 
| 
       29 
     | 
    
         
            -
             
     | 
| 
      
 26 
     | 
    
         
            +
            def _export_and_decomp(mod, args):
         
     | 
| 
      
 27 
     | 
    
         
            +
              ep = torch.export.export(mod, args)
         
     | 
| 
      
 28 
     | 
    
         
            +
              ep = fx_infra.safe_run_decompositions(ep, fx_infra.decomp.pre_lower_decomp())
         
     | 
| 
      
 29 
     | 
    
         
            +
              return ep
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
            def _to_mlir(ep: torch.export.ExportedProgram):
         
     | 
| 
       30 
33 
     | 
    
         
             
              return lowertools.exported_program_to_mlir_text(ep)
         
     | 
| 
       31 
34 
     | 
    
         | 
| 
       32 
35 
     | 
    
         | 
| 
         @@ -47,9 +50,9 @@ class TestMarkPattern(googletest.TestCase): 
     | 
|
| 
       47 
50 
     | 
    
         | 
| 
       48 
51 
     | 
    
         
             
                model = TestModel().eval()
         
     | 
| 
       49 
52 
     | 
    
         
             
                args = (torch.rand(20, 20),)
         
     | 
| 
       50 
     | 
    
         
            -
                exported_program =  
     | 
| 
      
 53 
     | 
    
         
            +
                exported_program = _export_and_decomp(model, args)
         
     | 
| 
       51 
54 
     | 
    
         
             
                mark_pattern.mark_pattern(exported_program.graph_module, pattern)
         
     | 
| 
       52 
     | 
    
         
            -
                mlir =  
     | 
| 
      
 55 
     | 
    
         
            +
                mlir = _to_mlir(exported_program)
         
     | 
| 
       53 
56 
     | 
    
         | 
| 
       54 
57 
     | 
    
         
             
                lowertools.assert_string_count(
         
     | 
| 
       55 
58 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -73,9 +76,9 @@ class TestMarkPattern(googletest.TestCase): 
     | 
|
| 
       73 
76 
     | 
    
         | 
| 
       74 
77 
     | 
    
         
             
                model = TestModel().eval()
         
     | 
| 
       75 
78 
     | 
    
         
             
                args = (torch.rand(20, 20),)
         
     | 
| 
       76 
     | 
    
         
            -
                exported_program =  
     | 
| 
      
 79 
     | 
    
         
            +
                exported_program = _export_and_decomp(model, args)
         
     | 
| 
       77 
80 
     | 
    
         
             
                mark_pattern.mark_pattern(exported_program.graph_module, pattern)
         
     | 
| 
       78 
     | 
    
         
            -
                mlir =  
     | 
| 
      
 81 
     | 
    
         
            +
                mlir = _to_mlir(exported_program)
         
     | 
| 
       79 
82 
     | 
    
         | 
| 
       80 
83 
     | 
    
         
             
                lowertools.assert_string_count(
         
     | 
| 
       81 
84 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -99,9 +102,9 @@ class TestMarkPattern(googletest.TestCase): 
     | 
|
| 
       99 
102 
     | 
    
         | 
| 
       100 
103 
     | 
    
         
             
                model = TestModel().eval()
         
     | 
| 
       101 
104 
     | 
    
         
             
                args = (torch.rand(20, 20),)
         
     | 
| 
       102 
     | 
    
         
            -
                exported_program =  
     | 
| 
      
 105 
     | 
    
         
            +
                exported_program = _export_and_decomp(model, args)
         
     | 
| 
       103 
106 
     | 
    
         
             
                mark_pattern.mark_pattern(exported_program.graph_module, pattern)
         
     | 
| 
       104 
     | 
    
         
            -
                mlir =  
     | 
| 
      
 107 
     | 
    
         
            +
                mlir = _to_mlir(exported_program)
         
     | 
| 
       105 
108 
     | 
    
         | 
| 
       106 
109 
     | 
    
         
             
                lowertools.assert_string_count(
         
     | 
| 
       107 
110 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -137,9 +140,9 @@ class TestMarkPattern(googletest.TestCase): 
     | 
|
| 
       137 
140 
     | 
    
         | 
| 
       138 
141 
     | 
    
         
             
                model = TestModel().eval()
         
     | 
| 
       139 
142 
     | 
    
         
             
                args = (torch.rand(10, 10),)
         
     | 
| 
       140 
     | 
    
         
            -
                exported_program =  
     | 
| 
      
 143 
     | 
    
         
            +
                exported_program = _export_and_decomp(model, args)
         
     | 
| 
       141 
144 
     | 
    
         
             
                mark_pattern.mark_pattern(exported_program.graph_module, pattern)
         
     | 
| 
       142 
     | 
    
         
            -
                mlir =  
     | 
| 
      
 145 
     | 
    
         
            +
                mlir = _to_mlir(exported_program)
         
     | 
| 
       143 
146 
     | 
    
         | 
| 
       144 
147 
     | 
    
         
             
                lowertools.assert_string_count(
         
     | 
| 
       145 
148 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -169,9 +172,9 @@ class TestMarkPattern(googletest.TestCase): 
     | 
|
| 
       169 
172 
     | 
    
         | 
| 
       170 
173 
     | 
    
         
             
                model = TestModel().eval()
         
     | 
| 
       171 
174 
     | 
    
         
             
                args = (torch.rand(20, 20), torch.rand(20, 20))
         
     | 
| 
       172 
     | 
    
         
            -
                exported_program =  
     | 
| 
      
 175 
     | 
    
         
            +
                exported_program = _export_and_decomp(model, args)
         
     | 
| 
       173 
176 
     | 
    
         
             
                mark_pattern.mark_pattern(exported_program.graph_module, pattern)
         
     | 
| 
       174 
     | 
    
         
            -
                mlir =  
     | 
| 
      
 177 
     | 
    
         
            +
                mlir = _to_mlir(exported_program)
         
     | 
| 
       175 
178 
     | 
    
         | 
| 
       176 
179 
     | 
    
         
             
                lowertools.assert_string_count(
         
     | 
| 
       177 
180 
     | 
    
         
             
                    self,
         
     | 
| 
         @@ -59,30 +59,3 @@ def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs): 
     | 
|
| 
       59 
59 
     | 
    
         
             
                ordered_tensor_constants = tuple()
         
     | 
| 
       60 
60 
     | 
    
         | 
| 
       61 
61 
     | 
    
         
             
              return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
         
     | 
| 
       62 
     | 
    
         
            -
             
     | 
| 
       63 
     | 
    
         
            -
             
     | 
| 
       64 
     | 
    
         
            -
            # TODO(b/331481564): Replace this with CanonicalizePass + run_decomposition
         
     | 
| 
       65 
     | 
    
         
            -
            def safe_run_decompositions(exported_program, decomp_table=None):
         
     | 
| 
       66 
     | 
    
         
            -
              for node in exported_program.graph.nodes:
         
     | 
| 
       67 
     | 
    
         
            -
                if node.target == torch.ops.aten.view.default:
         
     | 
| 
       68 
     | 
    
         
            -
                  # Passes or torch.export may generate aten.view nodes not respecting the
         
     | 
| 
       69 
     | 
    
         
            -
                  # tensor memory format. Changes all the aten.view to torch.reshape
         
     | 
| 
       70 
     | 
    
         
            -
                  # for retracing. If the input memory format is already contiguous,
         
     | 
| 
       71 
     | 
    
         
            -
                  # retracing in run_decomposition below would decompose torch.reshape
         
     | 
| 
       72 
     | 
    
         
            -
                  # back to one aten.view.
         
     | 
| 
       73 
     | 
    
         
            -
                  node.target = lambda self, size: torch.reshape(self.contiguous(), size)
         
     | 
| 
       74 
     | 
    
         
            -
             
     | 
| 
       75 
     | 
    
         
            -
              return exported_program.run_decompositions(decomp_table)
         
     | 
| 
       76 
     | 
    
         
            -
             
     | 
| 
       77 
     | 
    
         
            -
             
     | 
| 
       78 
     | 
    
         
            -
            def dummy_decomp_table():
         
     | 
| 
       79 
     | 
    
         
            -
              """Build dummy decomp table for run_decompositions without any decompositions.
         
     | 
| 
       80 
     | 
    
         
            -
             
     | 
| 
       81 
     | 
    
         
            -
              Compatible for torch<=2.5.
         
     | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
              Returns:
         
     | 
| 
       84 
     | 
    
         
            -
                Decomp table for ExportedProgram.run_decompositions.
         
     | 
| 
       85 
     | 
    
         
            -
              """
         
     | 
| 
       86 
     | 
    
         
            -
              return {
         
     | 
| 
       87 
     | 
    
         
            -
                  torch._ops.OperatorBase(): lambda: None,
         
     | 
| 
       88 
     | 
    
         
            -
              }
         
     | 
| 
         @@ -20,6 +20,7 @@ import io 
     | 
|
| 
       20 
20 
     | 
    
         
             
            import operator
         
     | 
| 
       21 
21 
     | 
    
         
             
            from typing import Any, Callable, Optional
         
     | 
| 
       22 
22 
     | 
    
         | 
| 
      
 23 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
       23 
24 
     | 
    
         
             
            from jax.lib import xla_extension
         
     | 
| 
       24 
25 
     | 
    
         
             
            from jax._src.lib.mlir import ir
         
     | 
| 
       25 
26 
     | 
    
         
             
            from jax._src.lib.mlir.dialects import func
         
     | 
| 
         @@ -302,16 +303,13 @@ def exported_program_to_mlir( 
     | 
|
| 
       302 
303 
     | 
    
         
             
                exported_program: torch.export.ExportedProgram,
         
     | 
| 
       303 
304 
     | 
    
         
             
            ) -> MlirLowered:
         
     | 
| 
       304 
305 
     | 
    
         
             
              """Lower the exported program to MLIR."""
         
     | 
| 
       305 
     | 
    
         
            -
              exported_program =  
     | 
| 
       306 
     | 
    
         
            -
                  exported_program, 
     | 
| 
      
 306 
     | 
    
         
            +
              exported_program = fx_infra.safe_run_decompositions(
         
     | 
| 
      
 307 
     | 
    
         
            +
                  exported_program,
         
     | 
| 
      
 308 
     | 
    
         
            +
                  fx_infra.decomp.pre_lower_decomp(),
         
     | 
| 
       307 
309 
     | 
    
         
             
              )
         
     | 
| 
       308 
     | 
    
         
            -
             
     | 
| 
       309 
310 
     | 
    
         
             
              _convert_i64_to_i32(exported_program)
         
     | 
| 
       310 
     | 
    
         
            -
             
     | 
| 
       311 
     | 
    
         
            -
               
     | 
| 
       312 
     | 
    
         
            -
              exported_program = _torch_future.safe_run_decompositions(
         
     | 
| 
       313 
     | 
    
         
            -
                  exported_program, _torch_future.dummy_decomp_table()
         
     | 
| 
       314 
     | 
    
         
            -
              )
         
     | 
| 
      
 311 
     | 
    
         
            +
              # Run decompositions for retracing and canonicalization.
         
     | 
| 
      
 312 
     | 
    
         
            +
              exported_program = fx_infra.safe_run_decompositions(exported_program, {})
         
     | 
| 
       315 
313 
     | 
    
         | 
| 
       316 
314 
     | 
    
         
             
              # Passes below mutate the exported program to a state not executable by torch.
         
     | 
| 
       317 
315 
     | 
    
         
             
              # Do not call run_decompositions after applying the passes.
         
     | 
| 
         @@ -15,6 +15,7 @@ 
     | 
|
| 
       15 
15 
     | 
    
         
             
            from . import _basic
         
     | 
| 
       16 
16 
     | 
    
         
             
            from . import _batch_norm
         
     | 
| 
       17 
17 
     | 
    
         
             
            from . import _convolution
         
     | 
| 
      
 18 
     | 
    
         
            +
            from . import _decomp_registry
         
     | 
| 
       18 
19 
     | 
    
         
             
            from . import _jax_lowerings
         
     | 
| 
       19 
20 
     | 
    
         
             
            from . import _layer_norm
         
     | 
| 
       20 
21 
     | 
    
         
             
            from . import _quantized_decomposed
         
     | 
| 
         @@ -22,6 +23,5 @@ from . import _rand 
     | 
|
| 
       22 
23 
     | 
    
         
             
            from . import context
         
     | 
| 
       23 
24 
     | 
    
         
             
            from . import registry
         
     | 
| 
       24 
25 
     | 
    
         
             
            from . import utils
         
     | 
| 
       25 
     | 
    
         
            -
            from .decomp import decompositions
         
     | 
| 
       26 
26 
     | 
    
         
             
            from .registry import lookup
         
     | 
| 
       27 
27 
     | 
    
         
             
            from .registry import lower
         
     | 
| 
         @@ -0,0 +1,65 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
      
 2 
     | 
    
         
            +
            #
         
     | 
| 
      
 3 
     | 
    
         
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
      
 4 
     | 
    
         
            +
            # you may not use this file except in compliance with the License.
         
     | 
| 
      
 5 
     | 
    
         
            +
            # You may obtain a copy of the License at
         
     | 
| 
      
 6 
     | 
    
         
            +
            #
         
     | 
| 
      
 7 
     | 
    
         
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
      
 8 
     | 
    
         
            +
            #
         
     | 
| 
      
 9 
     | 
    
         
            +
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
      
 10 
     | 
    
         
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
      
 11 
     | 
    
         
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
      
 12 
     | 
    
         
            +
            # See the License for the specific language governing permissions and
         
     | 
| 
      
 13 
     | 
    
         
            +
            # limitations under the License.
         
     | 
| 
      
 14 
     | 
    
         
            +
            # ==============================================================================
         
     | 
| 
      
 15 
     | 
    
         
            +
            """Torch export decompositions to run before lowering."""
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            from ai_edge_torch import fx_infra
         
     | 
| 
      
 18 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            fx_infra.decomp.update_pre_lower_decomp(
         
     | 
| 
      
 22 
     | 
    
         
            +
                torch._decomp.get_decompositions([
         
     | 
| 
      
 23 
     | 
    
         
            +
                    torch.ops.aten.upsample_nearest2d,
         
     | 
| 
      
 24 
     | 
    
         
            +
                    torch.ops.aten._native_batch_norm_legit.no_stats,
         
     | 
| 
      
 25 
     | 
    
         
            +
                    torch.ops.aten._native_batch_norm_legit_functional,
         
     | 
| 
      
 26 
     | 
    
         
            +
                    torch.ops.aten._adaptive_avg_pool2d,
         
     | 
| 
      
 27 
     | 
    
         
            +
                    torch.ops.aten._adaptive_avg_pool3d,
         
     | 
| 
      
 28 
     | 
    
         
            +
                    torch.ops.aten.grid_sampler_2d,
         
     | 
| 
      
 29 
     | 
    
         
            +
                    torch.ops.aten.native_group_norm,
         
     | 
| 
      
 30 
     | 
    
         
            +
                    torch.ops.aten.native_dropout,
         
     | 
| 
      
 31 
     | 
    
         
            +
                    torch.ops.aten.reflection_pad1d,
         
     | 
| 
      
 32 
     | 
    
         
            +
                    torch.ops.aten.reflection_pad2d,
         
     | 
| 
      
 33 
     | 
    
         
            +
                    torch.ops.aten.reflection_pad3d,
         
     | 
| 
      
 34 
     | 
    
         
            +
                    torch.ops.aten.replication_pad1d,
         
     | 
| 
      
 35 
     | 
    
         
            +
                    torch.ops.aten.replication_pad2d,
         
     | 
| 
      
 36 
     | 
    
         
            +
                    torch.ops.aten.replication_pad3d,
         
     | 
| 
      
 37 
     | 
    
         
            +
                    torch.ops.aten.addmm,
         
     | 
| 
      
 38 
     | 
    
         
            +
                ])
         
     | 
| 
      
 39 
     | 
    
         
            +
            )
         
     | 
| 
      
 40 
     | 
    
         
            +
             
     | 
| 
      
 41 
     | 
    
         
            +
            fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
         
     | 
| 
      
 42 
     | 
    
         
            +
             
     | 
| 
      
 43 
     | 
    
         
            +
            # Torch's default einsum impl/decompositions are less efficient and less
         
     | 
| 
      
 44 
     | 
    
         
            +
            # optimized through converter than JAX's impl. Disable einsum
         
     | 
| 
      
 45 
     | 
    
         
            +
            # decomposition to use JAX bridge for a more efficient lowering.
         
     | 
| 
      
 46 
     | 
    
         
            +
            fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.einsum.default)
         
     | 
| 
      
 47 
     | 
    
         
            +
             
     | 
| 
      
 48 
     | 
    
         
            +
             
     | 
| 
      
 49 
     | 
    
         
            +
            # Override noop aten op decompositions for faster run_decompositions.
         
     | 
| 
      
 50 
     | 
    
         
            +
            fx_infra.decomp.add_pre_convert_decomp(
         
     | 
| 
      
 51 
     | 
    
         
            +
                torch.ops.aten.alias.default, lambda x: x
         
     | 
| 
      
 52 
     | 
    
         
            +
            )
         
     | 
| 
      
 53 
     | 
    
         
            +
            fx_infra.decomp.add_pre_convert_decomp(
         
     | 
| 
      
 54 
     | 
    
         
            +
                torch.ops.aten.detach.default, lambda x: x
         
     | 
| 
      
 55 
     | 
    
         
            +
            )
         
     | 
| 
      
 56 
     | 
    
         
            +
             
     | 
| 
      
 57 
     | 
    
         
            +
            # Override _safe_softmax decompositions with regular softmax.
         
     | 
| 
      
 58 
     | 
    
         
            +
            # _safe_softmax introduces additional check-select ops to guard extreme
         
     | 
| 
      
 59 
     | 
    
         
            +
            # input values to softmax, which could make the converted model inefficient
         
     | 
| 
      
 60 
     | 
    
         
            +
            # on-device.
         
     | 
| 
      
 61 
     | 
    
         
            +
            if hasattr(torch.ops.aten, "_safe_softmax"):
         
     | 
| 
      
 62 
     | 
    
         
            +
              fx_infra.decomp.add_pre_convert_decomp(
         
     | 
| 
      
 63 
     | 
    
         
            +
                  torch.ops.aten._safe_softmax.default,
         
     | 
| 
      
 64 
     | 
    
         
            +
                  torch.softmax,
         
     | 
| 
      
 65 
     | 
    
         
            +
              )
         
     | 
    
        ai_edge_torch/version.py
    CHANGED
    
    
| 
         @@ -1,6 +1,6 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            Metadata-Version: 2.1
         
     | 
| 
       2 
2 
     | 
    
         
             
            Name: ai-edge-torch-nightly
         
     | 
| 
       3 
     | 
    
         
            -
            Version: 0.3.0.dev20250121
     | 
| 
      
 3 
     | 
    
         
            +
            Version: 0.3.0.dev20250123
         
     | 
| 
       4 
4 
     | 
    
         
             
            Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
         
     | 
| 
       5 
5 
     | 
    
         
             
            Home-page: https://github.com/google-ai-edge/ai-edge-torch
         
     | 
| 
       6 
6 
     | 
    
         
             
            Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
         
     | 
| 
         @@ -1,32 +1,32 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            ai_edge_torch/__init__.py,sha256= 
     | 
| 
      
 1 
     | 
    
         
            +
            ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
         
     | 
| 
       2 
2 
     | 
    
         
             
            ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
         
     | 
| 
       3 
3 
     | 
    
         
             
            ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
         
     | 
| 
       4 
     | 
    
         
            -
            ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
         
     | 
| 
       5 
4 
     | 
    
         
             
            ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
         
     | 
| 
       6 
     | 
    
         
            -
            ai_edge_torch/version.py,sha256= 
     | 
| 
      
 5 
     | 
    
         
            +
            ai_edge_torch/version.py,sha256=szrxg2aB7mcm59IL_QVIqapmbw9Nz8AQ28vc9684bqY,706
         
     | 
| 
       7 
6 
     | 
    
         
             
            ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       8 
     | 
    
         
            -
            ai_edge_torch/_convert/conversion.py,sha256= 
     | 
| 
      
 7 
     | 
    
         
            +
            ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
         
     | 
| 
       9 
8 
     | 
    
         
             
            ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
         
     | 
| 
       10 
9 
     | 
    
         
             
            ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
         
     | 
| 
       11 
10 
     | 
    
         
             
            ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
         
     | 
| 
       12 
11 
     | 
    
         
             
            ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
         
     | 
| 
       13 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/__init__.py,sha256= 
     | 
| 
       14 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256= 
     | 
| 
       15 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256= 
     | 
| 
       16 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256= 
     | 
| 
       17 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256= 
     | 
| 
       18 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256= 
     | 
| 
      
 12 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/__init__.py,sha256=dG4WIICk0FqCH9euvbYHHsybRN7B1cYcuxN_OYxmjWo,1263
         
     | 
| 
      
 13 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=a1KhqLetFb_efRHjX4T-zH0vF-U37Ha5I1CPIAsIluE,9211
         
     | 
| 
      
 14 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=3JyjiHpn17Zhfq3yGQXK5LMH71DQPXHb_4GOkP9uAjY,4251
         
     | 
| 
      
 15 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZvMl3Ivpqa3burVOLKFndEZuNmWKNxjq2mM,2386
         
     | 
| 
      
 16 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
         
     | 
| 
      
 17 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
         
     | 
| 
      
 18 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
         
     | 
| 
       19 
19 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
         
     | 
| 
       20 
20 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
         
     | 
| 
       21 
21 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
         
     | 
| 
       22 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256= 
     | 
| 
       23 
     | 
    
         
            -
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256= 
     | 
| 
      
 22 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
         
     | 
| 
      
 23 
     | 
    
         
            +
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=lgoH32l6zAbWTCpa_4-RWkHjqbNaPsBnhSObLIX8dL4,10551
         
     | 
| 
       24 
24 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
         
     | 
| 
       25 
25 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
         
     | 
| 
       26 
26 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
         
     | 
| 
       27 
27 
     | 
    
         
             
            ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
         
     | 
| 
       28 
28 
     | 
    
         
             
            ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       29 
     | 
    
         
            -
            ai_edge_torch/_convert/test/test_convert.py,sha256= 
     | 
| 
      
 29 
     | 
    
         
            +
            ai_edge_torch/_convert/test/test_convert.py,sha256=o6tuJkD-ESaQxLxJpN104qpchm3LCtPmHinzQxe6PSg,17226
         
     | 
| 
       30 
30 
     | 
    
         
             
            ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
         
     | 
| 
       31 
31 
     | 
    
         
             
            ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
         
     | 
| 
       32 
32 
     | 
    
         
             
            ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
         
     | 
| 
         @@ -37,6 +37,12 @@ ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf 
     | 
|
| 
       37 
37 
     | 
    
         
             
            ai_edge_torch/debug/test/test_culprit.py,sha256=fRN-8jJicawJ2mhPRQNAQUZ8AdGg-s0tYMXyhnLAlWw,3875
         
     | 
| 
       38 
38 
     | 
    
         
             
            ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOawhbgiSEu96PmioPE,1668
         
     | 
| 
       39 
39 
     | 
    
         
             
            ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
      
 40 
     | 
    
         
            +
            ai_edge_torch/fx_infra/__init__.py,sha256=APjkSqEfwDxcnI8k53rGi3Ef-G2L-M8fdaPGpxXtuiI,1347
         
     | 
| 
      
 41 
     | 
    
         
            +
            ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
         
     | 
| 
      
 42 
     | 
    
         
            +
            ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=ZbWheeZ8ydsxCk2aVGUgUynrkEkBOMjBCzPhS5uq4sU,2595
         
     | 
| 
      
 43 
     | 
    
         
            +
            ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
         
     | 
| 
      
 44 
     | 
    
         
            +
            ai_edge_torch/fx_infra/graph_utils.py,sha256=3UZAOHWOUh2LCj1E2_AKQn3gRDILi9JCdqSScjyOd4M,1535
         
     | 
| 
      
 45 
     | 
    
         
            +
            ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
         
     | 
| 
       40 
46 
     | 
    
         
             
            ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       41 
47 
     | 
    
         
             
            ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       42 
48 
     | 
    
         
             
            ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
         @@ -115,8 +121,8 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6 
     | 
|
| 
       115 
121 
     | 
    
         
             
            ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=VU0c5pgvrUtaTboT1xuDBGjpKOM85aqtaB_hYfSBuEk,2544
         
     | 
| 
       116 
122 
     | 
    
         
             
            ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
         
     | 
| 
       117 
123 
     | 
    
         
             
            ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
         
     | 
| 
       118 
     | 
    
         
            -
            ai_edge_torch/generative/fx_passes/__init__.py,sha256= 
     | 
| 
       119 
     | 
    
         
            -
            ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256= 
     | 
| 
      
 124 
     | 
    
         
            +
            ai_edge_torch/generative/fx_passes/__init__.py,sha256=4rFrppMRKlTwwZeX1ON_cdp4yUqoTOES161IZQkJF6c,1143
         
     | 
| 
      
 125 
     | 
    
         
            +
            ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=n5TbXdhBZi8jQe4j7-rox_MugMVvW8ReOhkTA3pfQkw,1919
         
     | 
| 
       120 
126 
     | 
    
         
             
            ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       121 
127 
     | 
    
         
             
            ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8kQ4U3YANfSiTJKn8,13776
         
     | 
| 
       122 
128 
     | 
    
         
             
            ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
         
     | 
| 
         @@ -159,11 +165,11 @@ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yG 
     | 
|
| 
       159 
165 
     | 
    
         
             
            ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
         
     | 
| 
       160 
166 
     | 
    
         
             
            ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
         
     | 
| 
       161 
167 
     | 
    
         
             
            ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
         
     | 
| 
       162 
     | 
    
         
            -
            ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256 
     | 
| 
       163 
     | 
    
         
            -
            ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256= 
     | 
| 
       164 
     | 
    
         
            -
            ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256= 
     | 
| 
      
 168 
     | 
    
         
            +
            ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=JsVmYrM_JEuN_smMHXUsRlo3Liapp7UyktbPpPARwDk,5386
         
     | 
| 
      
 169 
     | 
    
         
            +
            ai_edge_torch/hlfb/mark_pattern/fx_utils.py,sha256=YCtMgu-4w2BQ5fpnlpWC6IauKPf_tVqc7Ff91OTqlSw,1796
         
     | 
| 
      
 170 
     | 
    
         
            +
            ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=Ui6BrehF3zJJN7uTxKwbO2yCY9mYjbewlQzAxzZv9Es,10274
         
     | 
| 
       165 
171 
     | 
    
         
             
            ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       166 
     | 
    
         
            -
            ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256 
     | 
| 
      
 172 
     | 
    
         
            +
            ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=5kmOJWCc7sU1Hrqr1y17BtShUrssTfaV1sMyUvdMbsg,5573
         
     | 
| 
       167 
173 
     | 
    
         
             
            ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
         
     | 
| 
       168 
174 
     | 
    
         
             
            ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5SIrG0,3276
         
     | 
| 
       169 
175 
     | 
    
         
             
            ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5zo3MlX1QYW4c,4513
         
     | 
| 
         @@ -172,9 +178,9 @@ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUG 
     | 
|
| 
       172 
178 
     | 
    
         
             
            ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
         
     | 
| 
       173 
179 
     | 
    
         
             
            ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
         
     | 
| 
       174 
180 
     | 
    
         
             
            ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
         
     | 
| 
       175 
     | 
    
         
            -
            ai_edge_torch/odml_torch/_torch_future.py,sha256= 
     | 
| 
      
 181 
     | 
    
         
            +
            ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
         
     | 
| 
       176 
182 
     | 
    
         
             
            ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
         
     | 
| 
       177 
     | 
    
         
            -
            ai_edge_torch/odml_torch/export.py,sha256= 
     | 
| 
      
 183 
     | 
    
         
            +
            ai_edge_torch/odml_torch/export.py,sha256=YN7QPrQ8W6T3YVOdyIGadfSQuBroMjIqAMB9FeUa7Ho,13447
         
     | 
| 
       178 
184 
     | 
    
         
             
            ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
         
     | 
| 
       179 
185 
     | 
    
         
             
            ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
         
     | 
| 
       180 
186 
     | 
    
         
             
            ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
         
     | 
| 
         @@ -186,16 +192,16 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW 
     | 
|
| 
       186 
192 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
         
     | 
| 
       187 
193 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
         
     | 
| 
       188 
194 
     | 
    
         
             
            ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
         
     | 
| 
       189 
     | 
    
         
            -
            ai_edge_torch/odml_torch/lowerings/__init__.py,sha256= 
     | 
| 
      
 195 
     | 
    
         
            +
            ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
         
     | 
| 
       190 
196 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
         
     | 
| 
       191 
197 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
         
     | 
| 
       192 
198 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
         
     | 
| 
      
 199 
     | 
    
         
            +
            ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
         
     | 
| 
       193 
200 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=CJHWkmY4aAVQ5dmFsVc3Ox9TPkoLSNOfa96psD4CLRo,11561
         
     | 
| 
       194 
201 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
         
     | 
| 
       195 
202 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
         
     | 
| 
       196 
203 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
         
     | 
| 
       197 
204 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
         
     | 
| 
       198 
     | 
    
         
            -
            ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=0dIkVNN9ubgfwSxeaBFfTmWarWWG8Q1M0HX_1FShHik,2610
         
     | 
| 
       199 
205 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
         
     | 
| 
       200 
206 
     | 
    
         
             
            ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZmWcAHl0nTDgyI2g,6167
         
     | 
| 
       201 
207 
     | 
    
         
             
            ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
         
     | 
| 
         @@ -206,8 +212,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9 
     | 
|
| 
       206 
212 
     | 
    
         
             
            ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
         
     | 
| 
       207 
213 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
         
     | 
| 
       208 
214 
     | 
    
         
             
            ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
         
     | 
| 
       209 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       210 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       211 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       212 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
       213 
     | 
    
         
            -
            ai_edge_torch_nightly-0.3.0. 
     | 
| 
      
 215 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
         
     | 
| 
      
 216 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/METADATA,sha256=1IZCBOcKVCWbEfAQvEMgt39cuATDIzpK6AhW_gTnIY4,1966
         
     | 
| 
      
 217 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
         
     | 
| 
      
 218 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
         
     | 
| 
      
 219 
     | 
    
         
            +
            ai_edge_torch_nightly-0.3.0.dev20250123.dist-info/RECORD,,
         
     | 
| 
         @@ -1,69 +0,0 @@ 
     | 
|
| 
       1 
     | 
    
         
            -
            # Copyright 2024 The AI Edge Torch Authors.
         
     | 
| 
       2 
     | 
    
         
            -
            #
         
     | 
| 
       3 
     | 
    
         
            -
            # Licensed under the Apache License, Version 2.0 (the "License");
         
     | 
| 
       4 
     | 
    
         
            -
            # you may not use this file except in compliance with the License.
         
     | 
| 
       5 
     | 
    
         
            -
            # You may obtain a copy of the License at
         
     | 
| 
       6 
     | 
    
         
            -
            #
         
     | 
| 
       7 
     | 
    
         
            -
            #     http://www.apache.org/licenses/LICENSE-2.0
         
     | 
| 
       8 
     | 
    
         
            -
            #
         
     | 
| 
       9 
     | 
    
         
            -
            # Unless required by applicable law or agreed to in writing, software
         
     | 
| 
       10 
     | 
    
         
            -
            # distributed under the License is distributed on an "AS IS" BASIS,
         
     | 
| 
       11 
     | 
    
         
            -
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         
     | 
| 
       12 
     | 
    
         
            -
            # See the License for the specific language governing permissions and
         
     | 
| 
       13 
     | 
    
         
            -
            # limitations under the License.
         
     | 
| 
       14 
     | 
    
         
            -
            # ==============================================================================
         
     | 
| 
       15 
     | 
    
         
            -
            """Torch export decompositions to run before lowering."""
         
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
            import functools
         
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
            import torch
         
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
     | 
    
         
            -
            @functools.cache
         
     | 
| 
       23 
     | 
    
         
            -
            def decompositions():
         
     | 
| 
       24 
     | 
    
         
            -
              # Base: Core ATen decompositions
         
     | 
| 
       25 
     | 
    
         
            -
              decompositions = torch._decomp.core_aten_decompositions()
         
     | 
| 
       26 
     | 
    
         
            -
             
     | 
| 
       27 
     | 
    
         
            -
              decompositions.update(
         
     | 
| 
       28 
     | 
    
         
            -
                  torch._decomp.get_decompositions([
         
     | 
| 
       29 
     | 
    
         
            -
                      torch.ops.aten.upsample_nearest2d,
         
     | 
| 
       30 
     | 
    
         
            -
                      torch.ops.aten._native_batch_norm_legit.no_stats,
         
     | 
| 
       31 
     | 
    
         
            -
                      torch.ops.aten._native_batch_norm_legit_functional,
         
     | 
| 
       32 
     | 
    
         
            -
                      torch.ops.aten._adaptive_avg_pool2d,
         
     | 
| 
       33 
     | 
    
         
            -
                      torch.ops.aten._adaptive_avg_pool3d,
         
     | 
| 
       34 
     | 
    
         
            -
                      torch.ops.aten.grid_sampler_2d,
         
     | 
| 
       35 
     | 
    
         
            -
                      torch.ops.aten.native_group_norm,
         
     | 
| 
       36 
     | 
    
         
            -
                      torch.ops.aten.native_dropout,
         
     | 
| 
       37 
     | 
    
         
            -
                      torch.ops.aten.reflection_pad1d,
         
     | 
| 
       38 
     | 
    
         
            -
                      torch.ops.aten.reflection_pad2d,
         
     | 
| 
       39 
     | 
    
         
            -
                      torch.ops.aten.reflection_pad3d,
         
     | 
| 
       40 
     | 
    
         
            -
                      torch.ops.aten.replication_pad1d,
         
     | 
| 
       41 
     | 
    
         
            -
                      torch.ops.aten.replication_pad2d,
         
     | 
| 
       42 
     | 
    
         
            -
                      torch.ops.aten.replication_pad3d,
         
     | 
| 
       43 
     | 
    
         
            -
                      torch.ops.aten.addmm,
         
     | 
| 
       44 
     | 
    
         
            -
                  ])
         
     | 
| 
       45 
     | 
    
         
            -
              )
         
     | 
| 
       46 
     | 
    
         
            -
             
     | 
| 
       47 
     | 
    
         
            -
              torch._decomp.remove_decompositions(
         
     | 
| 
       48 
     | 
    
         
            -
                  decompositions,
         
     | 
| 
       49 
     | 
    
         
            -
                  [
         
     | 
| 
       50 
     | 
    
         
            -
                      torch.ops.aten.roll,
         
     | 
| 
       51 
     | 
    
         
            -
                      # Torch's default einsum impl/decompositions is less efficient and
         
     | 
| 
       52 
     | 
    
         
            -
                      # optimized through converter than JAX's impl. Disable einsum
         
     | 
| 
       53 
     | 
    
         
            -
                      # decomposition to use JAX bridge for a more efficient lowering.
         
     | 
| 
       54 
     | 
    
         
            -
                      torch.ops.aten.einsum.default,
         
     | 
| 
       55 
     | 
    
         
            -
                  ],
         
     | 
| 
       56 
     | 
    
         
            -
              )
         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
              # Override noop aten op decompositions for faster run_decompositions.
         
     | 
| 
       59 
     | 
    
         
            -
              decompositions[torch.ops.aten.alias.default] = lambda x: x
         
     | 
| 
       60 
     | 
    
         
            -
              decompositions[torch.ops.aten.detach.default] = lambda x: x
         
     | 
| 
       61 
     | 
    
         
            -
             
     | 
| 
       62 
     | 
    
         
            -
              # Override _safe_softmax decompositions with regular softmax.
         
     | 
| 
       63 
     | 
    
         
            -
              # _safe_softmax introduces additional check-select ops to guard extreme
         
     | 
| 
       64 
     | 
    
         
            -
              # input values to softmax, which could make the converted model inefficient
         
     | 
| 
       65 
     | 
    
         
            -
              # on-device.
         
     | 
| 
       66 
     | 
    
         
            -
              if hasattr(torch.ops.aten, "_safe_softmax"):
         
     | 
| 
       67 
     | 
    
         
            -
                decompositions[torch.ops.aten._safe_softmax.default] = torch.softmax
         
     | 
| 
       68 
     | 
    
         
            -
             
     | 
| 
       69 
     | 
    
         
            -
              return decompositions
         
     | 
| 
         
            File without changes
         
     | 
| 
         
            File without changes
         
     |