ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (104)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py CHANGED
@@ -13,11 +13,11 @@
  # limitations under the License.
  # ==============================================================================

- from .convert.converter import convert
- from .convert.converter import signature
- from .convert.to_channel_last_io import to_channel_last_io
- from .model import Model
- from .version import __version__
+ from ai_edge_torch._convert.converter import convert
+ from ai_edge_torch._convert.converter import signature
+ from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
+ from ai_edge_torch.model import Model
+ from ai_edge_torch.version import __version__


  def load(path: str) -> Model:
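
The __init__.py change above only swaps the relative imports for absolute imports rooted at the new private _convert package; the same public names (convert, signature, to_channel_last_io, Model, __version__) are still re-exported from the package root. A minimal sketch of the unchanged top-level API, assuming a trivial example module (ExampleModule, the sample shapes, and the output path are illustrative, not taken from this diff):

import ai_edge_torch
import torch


class ExampleModule(torch.nn.Module):

  def forward(self, x):
    return torch.nn.functional.relu(x)


module = ExampleModule().eval()  # .eval() avoids the training-mode warning emitted by conversion.py
sample_args = (torch.randn(1, 8),)
edge_model = ai_edge_torch.convert(module, sample_args)  # same call before and after the rename
edge_model.export("example.tflite")  # TfLiteModel.export is assumed here, not shown in this diff
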
ai_edge_torch/{convert → _convert}/conversion.py CHANGED
@@ -13,48 +13,44 @@
  # limitations under the License.
  # ==============================================================================

- import gc
  import logging
  import os
- from typing import Optional
+ from typing import Any, Optional

+ from ai_edge_torch import lowertools
  from ai_edge_torch import model
- from ai_edge_torch.convert import conversion_utils as cutils
- from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
- from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
- from ai_edge_torch.convert.fx_passes import CanonicalizePass
- from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
- from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
- from ai_edge_torch.convert.fx_passes import run_passes
- from ai_edge_torch.generative.fx_passes import run_generative_passes
+ from ai_edge_torch._convert import fx_passes
+ from ai_edge_torch._convert import signature
+ from ai_edge_torch.generative import fx_passes as generative_fx_passes
  from ai_edge_torch.quantize import quant_config as qcfg
  import torch
- from torch.export import ExportedProgram
- from torch_xla import stablehlo

  os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


  def _run_convert_passes(
- exported_program: ExportedProgram,
- ) -> ExportedProgram:
- exported_program = run_generative_passes(exported_program)
- return run_passes(
+ exported_program: torch.export.ExportedProgram,
+ ) -> torch.export.ExportedProgram:
+ exported_program = generative_fx_passes.run_generative_passes(
+ exported_program
+ )
+ return fx_passes.run_passes(
  exported_program,
  [
- BuildInterpolateCompositePass(),
- CanonicalizePass(),
- OptimizeLayoutTransposesPass(),
- CanonicalizePass(),
- BuildAtenCompositePass(),
- CanonicalizePass(),
- InjectMlirDebuginfoPass(),
- CanonicalizePass(),
+ fx_passes.BuildInterpolateCompositePass(),
+ fx_passes.CanonicalizePass(),
+ fx_passes.OptimizeLayoutTransposesPass(),
+ fx_passes.CanonicalizePass(),
+ fx_passes.BuildAtenCompositePass(),
+ fx_passes.CanonicalizePass(),
+ fx_passes.InjectMlirDebuginfoPass(),
+ fx_passes.CanonicalizePass(),
  ],
  )


- def _warn_training_modules(signatures: list[cutils.Signature]):
+ def _warn_training_modules(signatures: list[signature.Signature]):
+ """Warns the user if the module is in training mode (.eval not called)."""
  for sig in signatures:
  if not sig.module.training:
  continue
@@ -64,30 +60,39 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
  " module in evaluation mode with `module.eval()` for better on-device"
  " performance and compatibility."
  )
- if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
+ if len(signatures) == 1 and sig.name == model.DEFAULT_SIGNATURE_NAME:
  # User does not specify any signature names explicitly.
  message = message.format(sig_name="")
  else:
  message = message.format(sig_name=f'"{sig.name}" ')

- logging.warn(message)
+ logging.warning(message)


  def convert_signatures(
- signatures: list[cutils.Signature],
+ signatures: list[signature.Signature],
  *,
  quant_config: Optional[qcfg.QuantConfig] = None,
- _tfl_converter_flags: dict = {},
+ _tfl_converter_flags: Optional[dict[str, Any]],
  ) -> model.TfLiteModel:
- """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.
+ """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
+
  Args:
- signatures: The list of 'Signature' objects containing PyTorch modules to be converted.
+ signatures: The list of 'signature.Signature' objects containing PyTorch
+ modules to be converted.
  quant_config: User-defined quantization method and scheme of the model.
- _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
+ _tfl_converter_flags: A nested dictionary allowing setting flags for the
+ underlying tflite converter.
+
+ Returns:
+ The converted `model.TfLiteModel` object.
  """
+ if _tfl_converter_flags is None:
+ _tfl_converter_flags = {}
+
  _warn_training_modules(signatures)

- exported_programs: torch.export.ExportedProgram = [
+ exported_programs: torch.export.torch.export.ExportedProgram = [
  torch.export.export(
  sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
  )
@@ -96,23 +101,8 @@ def convert_signatures(

  # Apply default fx passes
  exported_programs = list(map(_run_convert_passes, exported_programs))
- shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
- cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
- for exported, sig in zip(exported_programs, signatures)
- ]
-
- merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
- cutils.merge_stablehlo_bundles(
- shlo_bundles, signatures, exported_programs
- )
- )
- del exported_programs
- del shlo_bundles
-
- gc.collect()
-
- tflite_model = cutils.convert_stablehlo_to_tflite(
- merged_shlo_graph_module,
+ tflite_model = lowertools.exported_programs_to_tflite(
+ exported_programs,
  signatures,
  quant_config=quant_config,
  _tfl_converter_flags=_tfl_converter_flags,
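
As the conversion.py diff above shows, each signature's dynamic_shapes is forwarded unchanged to torch.export.export. A hedged sketch of requesting a dynamic batch dimension through the public API (the Relu module and the dimension name are illustrative, not from this diff):

import ai_edge_torch
import torch


class Relu(torch.nn.Module):

  def forward(self, x):
    return x.relu()


batch = torch.export.Dim("batch")  # symbolic batch size understood by torch.export
edge_model = ai_edge_torch.convert(
    Relu().eval(),
    (torch.randn(2, 8),),
    dynamic_shapes={"x": {0: batch}},  # forwarded to torch.export.export unchanged
)
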
ai_edge_torch/_convert/conversion_utils.py ADDED
@@ -0,0 +1,64 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Any
+
+ from ai_edge_torch.quantize import quant_config as qcfg
+ import tensorflow as tf
+
+
+ def apply_tfl_converter_flags(
+ converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict[str, Any]
+ ):
+ """Applies TFLite converter flags to the converter.
+
+ Args:
+ converter: TFLite converter.
+ tfl_converter_flags: TFLite converter flags.
+ """
+
+ def _set_converter_flag(path: list[Any]):
+ if len(path) < 2:
+ raise ValueError("Expecting at least two values in the path.")
+
+ target_obj = converter
+ for idx in range(len(path) - 2):
+ target_obj = getattr(target_obj, path[idx])
+
+ setattr(target_obj, path[-2], path[-1])
+
+ def _iterate_dict_tree(flags_dict: dict[str, Any], path: list[Any]):
+ for key, value in flags_dict.items():
+ path.append(key)
+ if isinstance(value, dict):
+ _iterate_dict_tree(value, path)
+ else:
+ path.append(value)
+ _set_converter_flag(path)
+ path.pop()
+ path.pop()
+
+ _iterate_dict_tree(tfl_converter_flags, [])
+
+
+ def set_tfl_converter_quant_flags(
+ converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
+ ):
+ if quant_config is not None:
+ quantizer_mode = quant_config._quantizer_mode
+ if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
+ converter._experimental_qdq_conversion_mode = "DYNAMIC"
+ elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
+ converter._experimental_qdq_conversion_mode = "STATIC"
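
The new conversion_utils.apply_tfl_converter_flags walks the nested flags dict and applies each leaf with setattr, either directly on the converter or on a nested attribute such as target_spec. A sketch of a flags dict and the attribute writes it would produce, assuming standard tf.lite.TFLiteConverter attributes (the particular flags chosen are illustrative):

import tensorflow as tf

flags = {
    # path ["optimizations", <value>] -> setattr(converter, "optimizations", <value>)
    "optimizations": [tf.lite.Optimize.DEFAULT],
    # path ["target_spec", "supported_ops", <value>]
    # -> setattr(converter.target_spec, "supported_ops", <value>)
    "target_spec": {"supported_ops": [tf.lite.OpsSet.TFLITE_BUILTINS]},
}

# Downstream, the conversion path would call something equivalent to
#   apply_tfl_converter_flags(converter, flags)
# which is also what the _ai_edge_converter_flags / _tfl_converter_flags
# parameters in converter.py and conversion.py ultimately feed.
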
ai_edge_torch/{convert → _convert}/converter.py CHANGED
@@ -15,19 +15,23 @@

  from __future__ import annotations

- from typing import Any, Dict, Optional, Tuple, Union
+ from typing import Any, Optional, Tuple, Union

  from ai_edge_torch import model
- from ai_edge_torch.convert import conversion
- from ai_edge_torch.convert import conversion_utils as cutils
+ from ai_edge_torch._convert import conversion
+ from ai_edge_torch._convert import signature as signature_module
  from ai_edge_torch.quantize import quant_config as qcfg
  import torch


  class Converter:
+ """A converter for converting PyTorch models to edge models.
+
+ This class allows adding multiple signatures to the converted edge model.
+ """

  def __init__(self):
- self._signatures: list[cutils.Signature] = []
+ self._signatures: list[signature_module.Signature] = []

  def signature(
  self,
@@ -36,9 +40,9 @@ class Converter:
  sample_args=None,
  sample_kwargs=None,
  *,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
  ) -> Converter:
- """Alias to `add_signature`"""
+ """Functions as an alias to `add_signature`."""
  return self.add_signature(
  name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
  )
@@ -50,17 +54,24 @@ class Converter:
  sample_args=None,
  sample_kwargs=None,
  *,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
  ) -> Converter:
  """Allows adding a new named torch model along with sample args to the conversion.

  Args:
  name: The name of the signature included in the converted edge model.
  module: The torch module to be converted.
- sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
- sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
- dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
- See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
+ sample_args: Tuple of tensors by which the torch module will be traced
+ with prior to conversion.
+ sample_kwargs: Dict of str to tensor by which the torch module will be
+ traced with prior to conversion.
+ dynamic_shapes: Optional dict or tuple that specify dynamic shape
+ specifications for each input in original order. See
+ https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+ details.
+
+ Returns:
+ The converter object itself.

  Raises:
  ValueError: If a signature with the provided name already exists.
@@ -75,7 +86,7 @@ class Converter:
  raise ValueError("sample_args or sample_kwargs must be provided.")

  self._signatures.append(
- cutils.Signature(
+ signature_module.Signature(
  name,
  module,
  sample_args,
@@ -92,8 +103,8 @@ class Converter:
  sample_kwargs=None,
  *,
  quant_config: Optional[qcfg.QuantConfig] = None,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
- _ai_edge_converter_flags: dict = {},
+ dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+ _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
  ) -> model.TfLiteModel:
  """Finalizes the conversion and produces an edge model.

@@ -101,31 +112,44 @@

  edge_model = Converter().signature(name, module, args).convert()

- Or it could be used to set the default signature for the converted edge model:
+ Or it could be used to set the default signature for the converted edge
+ model:

  edge_model = Converter().convert(module, args)

  Args:
- name: The name of the signature included in the converted edge model.
  module: The torch module to be converted.
- sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
- sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
+ sample_args: Tuple of tensors by which the torch module will be traced
+ with prior to conversion.
+ sample_kwargs: Dict of str to tensor by which the torch module will be
+ traced with prior to conversion.
  quant_config: User-defined quantization method and scheme of the model.
- dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
- See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
- _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
- This gives access to an implementation detail of this function and so needs to be treated as such.
- Please do not rely on this parameter except for local debugging as this can be removed in a future release.
+ dynamic_shapes: Optional dict or tuple that specify dynamic shape
+ specifications for each input in original order. See
+ https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+ details.
+ _ai_edge_converter_flags: A nested dictionary allowing setting flags for
+ the underlying converter. This gives access to an implementation detail
+ of this function and so needs to be treated as such. Please do not rely
+ on this parameter except for local debugging as this can be removed in a
+ future release.
+
+ Returns:
+ The converted edge model.

  Raises:
- ValueError: If the arguments are not provided as expected. See the example in this functions's comment.
+ ValueError: If the arguments are not provided as expected. See the example
+ in this functions's comment.
  """
+ if _ai_edge_converter_flags is None:
+ _ai_edge_converter_flags = {}
+
  if module is not None:
  if (
  sample_args is not None or sample_kwargs is not None
  ): # both module and args provided
  self.add_signature(
- cutils.DEFAULT_SIGNATURE_NAME,
+ model.DEFAULT_SIGNATURE_NAME,
  module,
  sample_args,
  sample_kwargs,
@@ -136,7 +160,6 @@ class Converter:
  "sample_args or sample_kwargs must be provided if a module is"
  " specified."
  )
-
  return conversion.convert_signatures(
  self._signatures,
  quant_config=quant_config,
@@ -149,22 +172,28 @@ def signature(
  module: torch.nn.Module,
  sample_args=None,
  sample_kwargs=None,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
  ) -> Converter:
  """Initiates a Converter object with the provided signature.

  Args:
  name: The name of the signature included in the converted edge model.
  module: The torch module to be converted.
- sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
- sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
- dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
- See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
+ sample_args: Tuple of tensors by which the torch module will be traced with
+ prior to conversion.
+ sample_kwargs: Dict of str to tensor by which the torch module will be
+ traced with prior to conversion.
+ dynamic_shapes: Optional dict or tuple that specify dynamic shape
+ specifications for each input in original order. See
+ https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+ details.
+
+ Returns:
+ A Converter object with the provided signature.

  Example:
  converter = ai_edge_torch.signature(name, module, args)
  edge_model = converter.convert()
-
  """
  return Converter().signature(
  name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
@@ -177,27 +206,38 @@ def convert(
  sample_kwargs=None,
  *,
  quant_config: Optional[qcfg.QuantConfig] = None,
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
- _ai_edge_converter_flags: dict = {},
+ dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+ _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
  ) -> model.TfLiteModel:
- """Allows converting a PyTorch model to an edge model with one default signature in one step.
+ """Converts a PyTorch model to an edge model with a default signature.

  Args:
  module: The torch module to be converted.
- sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
- sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
+ sample_args: Tuple of tensors by which the torch module will be traced with
+ prior to conversion.
+ sample_kwargs: Dict of str to tensor by which the torch module will be
+ traced with prior to conversion.
  quant_config: User-defined quantization method and scheme of the model.
- dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
- See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
- _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
- This gives access to an implementation detail of this function and so needs to be treated as such.
- Please do not rely on this parameter except for local debugging as this can be removed in a future release.
+ dynamic_shapes: Optional dict or tuple that specify dynamic shape
+ specifications for each input in original order. See
+ https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+ details.
+ _ai_edge_converter_flags: A nested dictionary allowing setting flags for the
+ underlying converter. This gives access to an implementation detail of
+ this function and so needs to be treated as such. Please do not rely on
+ this parameter except for local debugging as this can be removed in a
+ future release.
+
+ Returns:
+ The converted edge model.

  Example:
  edge_model = ai_edge_torch.convert(module, args)
-
  """

+ if _ai_edge_converter_flags is None:
+ _ai_edge_converter_flags = {}
+
  return Converter().convert(
  module,
  sample_args,
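
converter.py keeps the two usage patterns described in its docstrings, one default signature via convert() or several named signatures via signature()/add_signature(), while replacing the mutable dict = {} defaults with Optional[...] = None. A small sketch of the multi-signature flow, assuming two toy modules (the signature names "encode" and "decode" are illustrative):

import ai_edge_torch
import torch


class Encoder(torch.nn.Module):

  def forward(self, x):
    return x * 2


class Decoder(torch.nn.Module):

  def forward(self, x):
    return x + 1


sample = (torch.randn(1, 4),)
converter = ai_edge_torch.signature("encode", Encoder().eval(), sample)
converter.add_signature("decode", Decoder().eval(), sample)
edge_model = converter.convert()  # one TfLiteModel containing both signatures
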
ai_edge_torch/{convert → _convert}/fx_passes/__init__.py CHANGED
@@ -15,15 +15,15 @@

  from typing import Sequence, Union

- from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassBase
- from ai_edge_torch.convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
- from ai_edge_torch.convert.fx_passes._pass_base import FxPassBase
- from ai_edge_torch.convert.fx_passes._pass_base import FxPassResult
- from ai_edge_torch.convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
- from ai_edge_torch.convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
- from ai_edge_torch.convert.fx_passes.canonicalize_pass import CanonicalizePass
- from ai_edge_torch.convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
- from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
+ from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
+ from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult # NOQA
+ from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
+ from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
+ from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass # NOQA
+ from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass # NOQA
+ from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
+ from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass # NOQA
+ from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass # NOQA
  from torch.export import ExportedProgram
  from torch.fx.passes.infra.pass_manager import pass_result_wrapper
  import torch.utils._pytree as pytree
ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py CHANGED
@@ -13,27 +13,23 @@
  # limitations under the License.
  # ==============================================================================

- import copy
- import functools
  from functools import reduce
  from typing import Any, Callable
-
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+ from ai_edge_torch import lowertools
  import torch
- from torch.fx import GraphModule
- from torch.fx import Node
- from torch.fx.passes.infra.pass_base import PassBase
- from torch.fx.passes.infra.pass_base import PassResult
+ from torch.fx.passes.infra import pass_base
  import torch.utils._pytree as pytree

- _composite_builders: dict[Callable, Callable[[GraphModule, Node], None]] = {}
+ _composite_builders: dict[
+ Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
+ ] = {}


  def _register_composite_builder(op):
  def inner(func):
  if isinstance(op, torch._ops.OpOverloadPacket):
- for overload in v.overloads():
- _composite_builders[getattr(v, overload)] = func
+ for overload in op.overloads():
+ _composite_builders[getattr(op, overload)] = func
  else:
  _composite_builders[op] = func
  return func
@@ -44,6 +40,19 @@ def _register_composite_builder(op):
  def _tree_map_to_composite_attr_values(
  values, *, stringify_incompatible_values=True
  ):
+ """Convert a tree of values to a tree of composite attribute values.
+
+ This is used for pre-processing op attributes before passing them to
+ the composite op as attributes.
+
+ Args:
+ values: A tree of values.
+ stringify_incompatible_values: If True, stringify values that are not
+ compatible with composite attributes.
+
+ Returns:
+ A tree of composite attribute values.
+ """

  def convert(value):
  nonlocal stringify_incompatible_values
@@ -60,6 +69,11 @@ def _tree_map_to_composite_attr_values(


  class TorchOpArgumentsMapper:
+ """A helper class to map op arguments to kwargs.
+
+ This is mainly used to extract the default values for op arguments and present
+ all arguments as kwargs.
+ """

  def __init__(self, op):
  if isinstance(op, torch._ops.OpOverloadPacket):
@@ -72,13 +86,21 @@ class TorchOpArgumentsMapper:
  ]

  def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
- """Inspect the op's schema and extract all its args and kwargs
- into one single kwargs dict, with default values for those
- unspecified args and kwargs.
+ """Extracts all arguments of the op as kwargs.
+
+ Inspect the op's schema and extract all its args and kwargs into one single
+ kwargs dict, with default values for those unspecified args and kwargs.
+
+ Args:
+ args: The op's arguments.
+ kwargs: The op's kwargs.
+
+ Returns:
+ A kwargs dict with all args and kwargs.
  """
  full_kwargs = {**(kwargs or {})}

- for arg, (name, default_value) in zip(args, self.arg_specs):
+ for arg, (name, _) in zip(args, self.arg_specs):
  full_kwargs[name] = arg

  for name, default_value in self.arg_specs[len(args) :]:
@@ -89,12 +111,13 @@


  @_register_composite_builder(torch.ops.aten.hardswish.default)
- def _aten_hardswish(gm: GraphModule, node: Node):
+ def _aten_hardswish(_: torch.fx.GraphModule, node: torch.fx.Node):
+ """Build a composite for aten.hardswish.default."""
  op = node.target

  def hardswish(self: torch.Tensor):
  nonlocal op
- builder = StableHLOCompositeBuilder("aten.hardswish.default")
+ builder = lowertools.StableHLOCompositeBuilder("aten.hardswish.default")
  self = builder.mark_inputs(self)
  output = op(self)
  output = builder.mark_outputs(output)
@@ -104,7 +127,8 @@ def _aten_hardswish(gm: GraphModule, node: Node):


  @_register_composite_builder(torch.ops.aten.gelu.default)
- def _aten_gelu(gm: GraphModule, node: Node):
+ def _aten_gelu(_: torch.fx.GraphModule, node: torch.fx.Node):
+ """Build a composite for aten.gelu.default."""
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)

@@ -120,7 +144,7 @@ def _aten_gelu(gm: GraphModule, node: Node):
  ):
  return op(*args, **kwargs)

- builder = StableHLOCompositeBuilder(
+ builder = lowertools.StableHLOCompositeBuilder(
  "aten.gelu.default",
  attr=_tree_map_to_composite_attr_values({
  "approximate": full_kwargs["approximate"],
@@ -135,7 +159,8 @@ def _aten_gelu(gm: GraphModule, node: Node):


  @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
- def _aten_avg_pool2d(gm: GraphModule, node: Node):
+ def _aten_avg_pool2d(_: torch.fx.GraphModule, node: torch.fx.Node):
+ """Build a composite for aten.avg_pool2d.default."""
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)

@@ -199,7 +224,7 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
  ):
  return op(*args, **kwargs)

- builder = StableHLOCompositeBuilder(
+ builder = lowertools.StableHLOCompositeBuilder(
  "aten.avg_pool2d.default",
  attr=_tree_map_to_composite_attr_values({
  "kernel_size": full_kwargs["kernel_size"],
@@ -220,7 +245,7 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):


  @_register_composite_builder(torch.ops.aten.embedding.default)
- def _aten_embedding(gm: GraphModule, node: Node):
+ def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)

@@ -237,7 +262,7 @@ def _aten_embedding(gm: GraphModule, node: Node):
  # Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
  idx = torch.reshape(idx, (idx.numel(),))

- builder = StableHLOCompositeBuilder("odml.embedding_lookup")
+ builder = lowertools.StableHLOCompositeBuilder("odml.embedding_lookup")
  full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
  idx,
  full_kwargs["weight"],
@@ -252,13 +277,13 @@
  node.target = embedding


- class BuildAtenCompositePass(PassBase):
+ class BuildAtenCompositePass(pass_base.PassBase):

- def call(self, graph_module: GraphModule):
+ def call(self, graph_module: torch.fx.GraphModule):
  for node in graph_module.graph.nodes:
  if node.target in _composite_builders:
  _composite_builders[node.target](graph_module, node)

  graph_module.graph.lint()
  graph_module.recompile()
- return PassResult(graph_module, True)
+ return pass_base.PassResult(graph_module, True)
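
The pass above routes each registered aten op through lowertools.StableHLOCompositeBuilder, which marks the tensors entering and leaving a region so that region survives lowering as a single composite op. A minimal sketch of that mark_inputs/mark_outputs pattern, mirroring the hardswish builder shown in the diff; the composite name "example.softplus" and the softplus body are illustrative, not part of this release:

from ai_edge_torch import lowertools
import torch


def softplus_composite(x: torch.Tensor) -> torch.Tensor:
  builder = lowertools.StableHLOCompositeBuilder("example.softplus")
  x = builder.mark_inputs(x)  # tensors entering the composite region
  y = torch.nn.functional.softplus(x)  # the region to be outlined as a composite
  return builder.mark_outputs(y)  # tensors leaving the composite region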