ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
- ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
- ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0
ai_edge_torch/__init__.py
CHANGED
@@ -13,11 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
-from .
-from .
-from .
-from .model import Model
-from .version import __version__
+from ai_edge_torch._convert.converter import convert
+from ai_edge_torch._convert.converter import signature
+from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
+from ai_edge_torch.model import Model
+from ai_edge_torch.version import __version__
 
 
 def load(path: str) -> Model:
ai_edge_torch/{convert → _convert}/conversion.py
CHANGED
@@ -13,48 +13,44 @@
 # limitations under the License.
 # ==============================================================================
 
-import gc
 import logging
 import os
-from typing import Optional
+from typing import Any, Optional
 
+from ai_edge_torch import lowertools
 from ai_edge_torch import model
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.convert.fx_passes import CanonicalizePass
-from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
-from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
-from ai_edge_torch.convert.fx_passes import run_passes
-from ai_edge_torch.generative.fx_passes import run_generative_passes
+from ai_edge_torch._convert import fx_passes
+from ai_edge_torch._convert import signature
+from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
-from torch.export import ExportedProgram
-from torch_xla import stablehlo
 
 os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
 
 
 def _run_convert_passes(
-    exported_program: ExportedProgram,
-) -> ExportedProgram:
-  exported_program = run_generative_passes(
-
+    exported_program: torch.export.ExportedProgram,
+) -> torch.export.ExportedProgram:
+  exported_program = generative_fx_passes.run_generative_passes(
+      exported_program
+  )
+  return fx_passes.run_passes(
       exported_program,
       [
-          BuildInterpolateCompositePass(),
-          CanonicalizePass(),
-          OptimizeLayoutTransposesPass(),
-          CanonicalizePass(),
-          BuildAtenCompositePass(),
-          CanonicalizePass(),
-          InjectMlirDebuginfoPass(),
-          CanonicalizePass(),
+          fx_passes.BuildInterpolateCompositePass(),
+          fx_passes.CanonicalizePass(),
+          fx_passes.OptimizeLayoutTransposesPass(),
+          fx_passes.CanonicalizePass(),
+          fx_passes.BuildAtenCompositePass(),
+          fx_passes.CanonicalizePass(),
+          fx_passes.InjectMlirDebuginfoPass(),
+          fx_passes.CanonicalizePass(),
       ],
   )
 
 
-def _warn_training_modules(signatures: list[
+def _warn_training_modules(signatures: list[signature.Signature]):
+  """Warns the user if the module is in training mode (.eval not called)."""
   for sig in signatures:
     if not sig.module.training:
       continue
@@ -64,30 +60,39 @@ def _warn_training_modules(signatures: list[cutils.Signature]):
         " module in evaluation mode with `module.eval()` for better on-device"
         " performance and compatibility."
     )
-    if len(signatures) == 1 and sig.name ==
+    if len(signatures) == 1 and sig.name == model.DEFAULT_SIGNATURE_NAME:
       # User does not specify any signature names explicitly.
       message = message.format(sig_name="")
     else:
       message = message.format(sig_name=f'"{sig.name}" ')
 
-    logging.
+    logging.warning(message)
 
 
 def convert_signatures(
-    signatures: list[
+    signatures: list[signature.Signature],
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
-    _tfl_converter_flags: dict
+    _tfl_converter_flags: Optional[dict[str, Any]],
 ) -> model.TfLiteModel:
-  """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.
+  """Converts a list of `signature.Signature`s and embeds them into one `model.TfLiteModel`.
+
   Args:
-    signatures: The list of 'Signature' objects containing PyTorch
+    signatures: The list of 'signature.Signature' objects containing PyTorch
+      modules to be converted.
    quant_config: User-defined quantization method and scheme of the model.
-    _tfl_converter_flags: A nested dictionary allowing setting flags for the
+    _tfl_converter_flags: A nested dictionary allowing setting flags for the
+      underlying tflite converter.
+
+  Returns:
+    The converted `model.TfLiteModel` object.
   """
+  if _tfl_converter_flags is None:
+    _tfl_converter_flags = {}
+
   _warn_training_modules(signatures)
 
-  exported_programs: torch.export.ExportedProgram = [
+  exported_programs: torch.export.torch.export.ExportedProgram = [
       torch.export.export(
           sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
       )
@@ -96,23 +101,8 @@ def convert_signatures(
 
   # Apply default fx passes
   exported_programs = list(map(_run_convert_passes, exported_programs))
-
-
-      for exported, sig in zip(exported_programs, signatures)
-  ]
-
-  merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
-      cutils.merge_stablehlo_bundles(
-          shlo_bundles, signatures, exported_programs
-      )
-  )
-  del exported_programs
-  del shlo_bundles
-
-  gc.collect()
-
-  tflite_model = cutils.convert_stablehlo_to_tflite(
-      merged_shlo_graph_module,
+  tflite_model = lowertools.exported_programs_to_tflite(
+      exported_programs,
       signatures,
       quant_config=quant_config,
       _tfl_converter_flags=_tfl_converter_flags,
ai_edge_torch/_convert/conversion_utils.py
ADDED
@@ -0,0 +1,64 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Any
+
+from ai_edge_torch.quantize import quant_config as qcfg
+import tensorflow as tf
+
+
+def apply_tfl_converter_flags(
+    converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict[str, Any]
+):
+  """Applies TFLite converter flags to the converter.
+
+  Args:
+    converter: TFLite converter.
+    tfl_converter_flags: TFLite converter flags.
+  """
+
+  def _set_converter_flag(path: list[Any]):
+    if len(path) < 2:
+      raise ValueError("Expecting at least two values in the path.")
+
+    target_obj = converter
+    for idx in range(len(path) - 2):
+      target_obj = getattr(target_obj, path[idx])
+
+    setattr(target_obj, path[-2], path[-1])
+
+  def _iterate_dict_tree(flags_dict: dict[str, Any], path: list[Any]):
+    for key, value in flags_dict.items():
+      path.append(key)
+      if isinstance(value, dict):
+        _iterate_dict_tree(value, path)
+      else:
+        path.append(value)
+        _set_converter_flag(path)
+        path.pop()
+      path.pop()
+
+  _iterate_dict_tree(tfl_converter_flags, [])
+
+
+def set_tfl_converter_quant_flags(
+    converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
+):
+  if quant_config is not None:
+    quantizer_mode = quant_config._quantizer_mode
+    if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
+      converter._experimental_qdq_conversion_mode = "DYNAMIC"
+    elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
+      converter._experimental_qdq_conversion_mode = "STATIC"
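For context on the new conversion_utils helper: a rough usage sketch (not from the package) of how a nested flags dict is walked by apply_tfl_converter_flags and mapped onto converter attributes, assuming standard tf.lite converter fields; the saved-model path is a placeholder.

import tensorflow as tf
from ai_edge_torch._convert import conversion_utils

converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/saved_model")  # placeholder path

flags = {
    # Top-level key: becomes setattr(converter, "optimizations", [...]).
    "optimizations": [tf.lite.Optimize.DEFAULT],
    # Nested dict: becomes setattr(converter.target_spec, "supported_ops", [...]).
    "target_spec": {"supported_ops": [tf.lite.OpsSet.TFLITE_BUILTINS]},
}
conversion_utils.apply_tfl_converter_flags(converter, flags)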
ai_edge_torch/{convert → _convert}/converter.py
CHANGED
@@ -15,19 +15,23 @@
 
 from __future__ import annotations
 
-from typing import Any,
+from typing import Any, Optional, Tuple, Union
 
 from ai_edge_torch import model
-from ai_edge_torch.
-from ai_edge_torch.
+from ai_edge_torch._convert import conversion
+from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
 
 class Converter:
+  """A converter for converting PyTorch models to edge models.
+
+  This class allows adding multiple signatures to the converted edge model.
+  """
 
   def __init__(self):
-    self._signatures: list[
+    self._signatures: list[signature_module.Signature] = []
 
   def signature(
       self,
@@ -36,9 +40,9 @@ class Converter:
       sample_args=None,
       sample_kwargs=None,
       *,
-      dynamic_shapes: Optional[Union[
+      dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
   ) -> Converter:
-    """
+    """Functions as an alias to `add_signature`."""
     return self.add_signature(
         name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
     )
@@ -50,17 +54,24 @@ class Converter:
       sample_args=None,
       sample_kwargs=None,
       *,
-      dynamic_shapes: Optional[Union[
+      dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
   ) -> Converter:
     """Allows adding a new named torch model along with sample args to the conversion.
 
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of tensors by which the torch module will be traced
-
-
-
+      sample_args: Tuple of tensors by which the torch module will be traced
+        with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be
+        traced with prior to conversion.
+      dynamic_shapes: Optional dict or tuple that specify dynamic shape
+        specifications for each input in original order. See
+        https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+        details.
+
+    Returns:
+      The converter object itself.
 
     Raises:
       ValueError: If a signature with the provided name already exists.
@@ -75,7 +86,7 @@ class Converter:
       raise ValueError("sample_args or sample_kwargs must be provided.")
 
     self._signatures.append(
-
+        signature_module.Signature(
            name,
            module,
            sample_args,
@@ -92,8 +103,8 @@ class Converter:
       sample_kwargs=None,
       *,
       quant_config: Optional[qcfg.QuantConfig] = None,
-      dynamic_shapes: Optional[Union[
-      _ai_edge_converter_flags: dict =
+      dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+      _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
   ) -> model.TfLiteModel:
     """Finalizes the conversion and produces an edge model.
 
@@ -101,31 +112,44 @@ class Converter:
 
       edge_model = Converter().signature(name, module, args).convert()
 
-    Or it could be used to set the default signature for the converted edge
+    Or it could be used to set the default signature for the converted edge
+    model:
 
       edge_model = Converter().convert(module, args)
 
     Args:
-      name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of tensors by which the torch module will be traced
-
+      sample_args: Tuple of tensors by which the torch module will be traced
+        with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be
+        traced with prior to conversion.
       quant_config: User-defined quantization method and scheme of the model.
-      dynamic_shapes: Optional dict or tuple that specify dynamic shape
-
-
-
-
+      dynamic_shapes: Optional dict or tuple that specify dynamic shape
+        specifications for each input in original order. See
+        https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+        details.
+      _ai_edge_converter_flags: A nested dictionary allowing setting flags for
+        the underlying converter. This gives access to an implementation detail
+        of this function and so needs to be treated as such. Please do not rely
+        on this parameter except for local debugging as this can be removed in a
+        future release.
+
+    Returns:
+      The converted edge model.
 
     Raises:
-      ValueError: If the arguments are not provided as expected. See the example
+      ValueError: If the arguments are not provided as expected. See the example
+        in this functions's comment.
     """
+    if _ai_edge_converter_flags is None:
+      _ai_edge_converter_flags = {}
+
     if module is not None:
       if (
          sample_args is not None or sample_kwargs is not None
      ):  # both module and args provided
        self.add_signature(
-
+            model.DEFAULT_SIGNATURE_NAME,
            module,
            sample_args,
            sample_kwargs,
@@ -136,7 +160,6 @@ class Converter:
          "sample_args or sample_kwargs must be provided if a module is"
          " specified."
      )
-
     return conversion.convert_signatures(
        self._signatures,
        quant_config=quant_config,
@@ -149,22 +172,28 @@ def signature(
    module: torch.nn.Module,
    sample_args=None,
    sample_kwargs=None,
-    dynamic_shapes: Optional[Union[
+    dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
 ) -> Converter:
  """Initiates a Converter object with the provided signature.
 
  Args:
    name: The name of the signature included in the converted edge model.
    module: The torch module to be converted.
-    sample_args: Tuple of tensors by which the torch module will be traced with
-
-
-
+    sample_args: Tuple of tensors by which the torch module will be traced with
+      prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be
+      traced with prior to conversion.
+    dynamic_shapes: Optional dict or tuple that specify dynamic shape
+      specifications for each input in original order. See
+      https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+      details.
+
+  Returns:
+    A Converter object with the provided signature.
 
  Example:
    converter = ai_edge_torch.signature(name, module, args)
    edge_model = converter.convert()
-
  """
  return Converter().signature(
      name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
@@ -177,27 +206,38 @@ def convert(
    sample_kwargs=None,
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
-    dynamic_shapes: Optional[Union[
-    _ai_edge_converter_flags: dict =
+    dynamic_shapes: Optional[Union[dict[str, Any], Tuple[Any, ...]]] = None,
+    _ai_edge_converter_flags: Optional[dict[str, Any]] = None,
 ) -> model.TfLiteModel:
-  """
+  """Converts a PyTorch model to an edge model with a default signature.
 
  Args:
    module: The torch module to be converted.
-    sample_args: Tuple of tensors by which the torch module will be traced with
-
+    sample_args: Tuple of tensors by which the torch module will be traced with
+      prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be
+      traced with prior to conversion.
    quant_config: User-defined quantization method and scheme of the model.
-    dynamic_shapes: Optional dict or tuple that specify dynamic shape
-
-
-
-
+    dynamic_shapes: Optional dict or tuple that specify dynamic shape
+      specifications for each input in original order. See
+      https://pytorch.org/docs/stable/export.html#expressing-dynamism for more
+      details.
+    _ai_edge_converter_flags: A nested dictionary allowing setting flags for the
+      underlying converter. This gives access to an implementation detail of
+      this function and so needs to be treated as such. Please do not rely on
+      this parameter except for local debugging as this can be removed in a
+      future release.
+
+  Returns:
+    The converted edge model.
 
  Example:
    edge_model = ai_edge_torch.convert(module, args)
-
  """
 
+  if _ai_edge_converter_flags is None:
+    _ai_edge_converter_flags = {}
+
  return Converter().convert(
      module,
      sample_args,
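A minimal end-to-end sketch consistent with the converter.py docstrings above; the module, shapes, and output path are placeholders, and TfLiteModel.export is assumed from the package's public API rather than confirmed by this diff.

import torch
import ai_edge_torch

module = torch.nn.Linear(4, 2).eval()  # placeholder model
sample_args = (torch.randn(2, 4),)

# Default signature, static shapes.
edge_model = ai_edge_torch.convert(module, sample_args)

# Same conversion with a dynamic batch dimension, expressed with
# torch.export's dynamism spec as referenced in the docstrings.
batch = torch.export.Dim("batch")
edge_model = ai_edge_torch.convert(
    module, sample_args, dynamic_shapes=({0: batch},)
)

edge_model.export("/tmp/linear.tflite")  # assumption: TfLiteModel exposes export()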
ai_edge_torch/{convert → _convert}/fx_passes/__init__.py
CHANGED
@@ -15,15 +15,15 @@
 
 from typing import Sequence, Union
 
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
-from ai_edge_torch.
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassBase
+from ai_edge_torch._convert.fx_passes._pass_base import ExportedProgramPassResult  # NOQA
+from ai_edge_torch._convert.fx_passes._pass_base import FxPassBase
+from ai_edge_torch._convert.fx_passes._pass_base import FxPassResult
+from ai_edge_torch._convert.fx_passes.build_aten_composite_pass import BuildAtenCompositePass  # NOQA
+from ai_edge_torch._convert.fx_passes.build_interpolate_composite_pass import BuildInterpolateCompositePass  # NOQA
+from ai_edge_torch._convert.fx_passes.canonicalize_pass import CanonicalizePass
+from ai_edge_torch._convert.fx_passes.inject_mlir_debuginfo_pass import InjectMlirDebuginfoPass  # NOQA
+from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import OptimizeLayoutTransposesPass  # NOQA
 from torch.export import ExportedProgram
 from torch.fx.passes.infra.pass_manager import pass_result_wrapper
 import torch.utils._pytree as pytree
ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py
CHANGED
@@ -13,27 +13,23 @@
 # limitations under the License.
 # ==============================================================================
 
-import copy
-import functools
 from functools import reduce
 from typing import Any, Callable
-
-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
+from ai_edge_torch import lowertools
 import torch
-from torch.fx import
-from torch.fx import Node
-from torch.fx.passes.infra.pass_base import PassBase
-from torch.fx.passes.infra.pass_base import PassResult
+from torch.fx.passes.infra import pass_base
 import torch.utils._pytree as pytree
 
-_composite_builders: dict[
+_composite_builders: dict[
+    Callable, Callable[[torch.fx.GraphModule, torch.fx.Node], None]
+] = {}
 
 
 def _register_composite_builder(op):
  def inner(func):
    if isinstance(op, torch._ops.OpOverloadPacket):
-      for overload in
-        _composite_builders[getattr(
+      for overload in op.overloads():
+        _composite_builders[getattr(op, overload)] = func
    else:
      _composite_builders[op] = func
    return func
@@ -44,6 +40,19 @@ def _register_composite_builder(op):
 def _tree_map_to_composite_attr_values(
    values, *, stringify_incompatible_values=True
 ):
+  """Convert a tree of values to a tree of composite attribute values.
+
+  This is used for pre-processing op attributes before passing them to
+  the composite op as attributes.
+
+  Args:
+    values: A tree of values.
+    stringify_incompatible_values: If True, stringify values that are not
+      compatible with composite attributes.
+
+  Returns:
+    A tree of composite attribute values.
+  """
 
  def convert(value):
    nonlocal stringify_incompatible_values
@@ -60,6 +69,11 @@ def _tree_map_to_composite_attr_values(
 
 
 class TorchOpArgumentsMapper:
+  """A helper class to map op arguments to kwargs.
+
+  This is mainly used to extract the default values for op arguments and present
+  all arguments as kwargs.
+  """
 
  def __init__(self, op):
    if isinstance(op, torch._ops.OpOverloadPacket):
@@ -72,13 +86,21 @@ class TorchOpArgumentsMapper:
    ]
 
  def get_full_kwargs(self, args, kwargs=None) -> dict[str, Any]:
-    """
-
-
+    """Extracts all arguments of the op as kwargs.
+
+    Inspect the op's schema and extract all its args and kwargs into one single
+    kwargs dict, with default values for those unspecified args and kwargs.
+
+    Args:
+      args: The op's arguments.
+      kwargs: The op's kwargs.
+
+    Returns:
+      A kwargs dict with all args and kwargs.
    """
    full_kwargs = {**(kwargs or {})}
 
-    for arg, (name,
+    for arg, (name, _) in zip(args, self.arg_specs):
      full_kwargs[name] = arg
 
    for name, default_value in self.arg_specs[len(args) :]:
@@ -89,12 +111,13 @@ class TorchOpArgumentsMapper:
 
 
 @_register_composite_builder(torch.ops.aten.hardswish.default)
-def _aten_hardswish(
+def _aten_hardswish(_: torch.fx.GraphModule, node: torch.fx.Node):
+  """Build a composite for aten.hardswish.default."""
  op = node.target
 
  def hardswish(self: torch.Tensor):
    nonlocal op
-    builder = StableHLOCompositeBuilder("aten.hardswish.default")
+    builder = lowertools.StableHLOCompositeBuilder("aten.hardswish.default")
    self = builder.mark_inputs(self)
    output = op(self)
    output = builder.mark_outputs(output)
@@ -104,7 +127,8 @@ def _aten_hardswish(gm: GraphModule, node: Node):
 
 
 @_register_composite_builder(torch.ops.aten.gelu.default)
-def _aten_gelu(
+def _aten_gelu(_: torch.fx.GraphModule, node: torch.fx.Node):
+  """Build a composite for aten.gelu.default."""
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)
 
@@ -120,7 +144,7 @@ def _aten_gelu(gm: GraphModule, node: Node):
  ):
    return op(*args, **kwargs)
 
-  builder = StableHLOCompositeBuilder(
+  builder = lowertools.StableHLOCompositeBuilder(
      "aten.gelu.default",
      attr=_tree_map_to_composite_attr_values({
          "approximate": full_kwargs["approximate"],
@@ -135,7 +159,8 @@ def _aten_gelu(gm: GraphModule, node: Node):
 
 
 @_register_composite_builder(torch.ops.aten.avg_pool2d.default)
-def _aten_avg_pool2d(
+def _aten_avg_pool2d(_: torch.fx.GraphModule, node: torch.fx.Node):
+  """Build a composite for aten.avg_pool2d.default."""
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)
 
@@ -199,7 +224,7 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
  ):
    return op(*args, **kwargs)
 
-  builder = StableHLOCompositeBuilder(
+  builder = lowertools.StableHLOCompositeBuilder(
      "aten.avg_pool2d.default",
      attr=_tree_map_to_composite_attr_values({
          "kernel_size": full_kwargs["kernel_size"],
@@ -220,7 +245,7 @@ def _aten_avg_pool2d(gm: GraphModule, node: Node):
 
 
 @_register_composite_builder(torch.ops.aten.embedding.default)
-def _aten_embedding(gm: GraphModule, node: Node):
+def _aten_embedding(gm: torch.fx.GraphModule, node: torch.fx.Node):
  op = node.target
  args_mapper = TorchOpArgumentsMapper(op)
 
@@ -237,7 +262,7 @@ def _aten_embedding(gm: GraphModule, node: Node):
  # Explicitly reshape to 1D. This places the ReshapeOp outside of the HLFB.
  idx = torch.reshape(idx, (idx.numel(),))
 
-  builder = StableHLOCompositeBuilder("odml.embedding_lookup")
+  builder = lowertools.StableHLOCompositeBuilder("odml.embedding_lookup")
  full_kwargs["indices"], full_kwargs["weight"] = builder.mark_inputs(
      idx,
      full_kwargs["weight"],
@@ -252,13 +277,13 @@ def _aten_embedding(gm: GraphModule, node: Node):
  node.target = embedding
 
 
-class BuildAtenCompositePass(PassBase):
+class BuildAtenCompositePass(pass_base.PassBase):
 
-  def call(self, graph_module: GraphModule):
+  def call(self, graph_module: torch.fx.GraphModule):
    for node in graph_module.graph.nodes:
      if node.target in _composite_builders:
        _composite_builders[node.target](graph_module, node)
 
    graph_module.graph.lint()
    graph_module.recompile()
-    return PassResult(graph_module, True)
+    return pass_base.PassResult(graph_module, True)
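For readers new to this pass, a hypothetical builder registration (not part of this release) that follows the same pattern as _aten_hardswish above: it wraps a single aten op in a StableHLO composite by marking inputs and outputs and swapping the node target.

# Hypothetical composite builder, illustrative only; mirrors the pattern in the diff.
@_register_composite_builder(torch.ops.aten.relu.default)
def _aten_relu(_: torch.fx.GraphModule, node: torch.fx.Node):
  op = node.target

  def relu(self: torch.Tensor):
    nonlocal op
    # Mark the composite boundary so the lowered StableHLO keeps aten.relu
    # as a single high-level function block (HLFB).
    builder = lowertools.StableHLOCompositeBuilder("aten.relu.default")
    self = builder.mark_inputs(self)
    output = op(self)
    output = builder.mark_outputs(output)
    return output

  node.target = relu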