ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +30 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +330 -0
- ai_edge_torch/convert/converter.py +171 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
- ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +273 -0
- ai_edge_torch/convert/test/test_convert_composites.py +171 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/debug/__init__.py +16 -0
- ai_edge_torch/debug/culprit.py +423 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +288 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +103 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +135 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/test_model_conversion.py +201 -0
- ai_edge_torch/generative/test/test_quantize.py +109 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +290 -0
- ai_edge_torch/generative/utilities/t5_loader.py +467 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +134 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +85 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from .convert.converter import convert
|
|
17
|
+
from .convert.converter import signature
|
|
18
|
+
from .model import Model
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def load(path: str) -> Model:
    """Load a serialized ai_edge_torch model from disk.

    Args:
      path: Location of the serialized ai_edge_torch model.

    Returns:
      The `ai_edge_torch.model.Model` object produced by `Model.load`.
    """
    loaded_model = Model.load(path)
    return loaded_model
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import gc
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from torch.export import ExportedProgram
|
|
23
|
+
from torch_xla import stablehlo
|
|
24
|
+
|
|
25
|
+
from ai_edge_torch import model
|
|
26
|
+
from ai_edge_torch.convert import conversion_utils as cutils
|
|
27
|
+
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
|
|
28
|
+
from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass # NOQA
|
|
29
|
+
from ai_edge_torch.convert.fx_passes import CanonicalizePass
|
|
30
|
+
from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
|
|
31
|
+
from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
|
|
32
|
+
from ai_edge_torch.convert.fx_passes import run_passes
|
|
33
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
34
|
+
|
|
35
|
+
# Module-level side effect: the flag is exported to the process environment as
# soon as this module is imported.
# NOTE(review): presumably read by torch_xla to enable unbounded dynamic shapes
# during StableHLO export - confirm against the torch_xla documentation.
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _run_convert_passes(
    exported_program: ExportedProgram,
) -> ExportedProgram:
    """Run the default fx pass pipeline used before StableHLO lowering.

    The pipeline consists of four rewriting passes, each immediately followed
    by a `CanonicalizePass`.

    Args:
      exported_program: The program produced by `torch.export.export`.

    Returns:
      Whatever `run_passes` returns for the given program and pipeline.
    """
    rewriting_passes = (
        BuildUpsampleBilinear2DCompositePass,
        OptimizeLayoutTransposesPass,
        BuildAtenCompositePass,
        InjectMlirDebuginfoPass,
    )

    pipeline = []
    for pass_cls in rewriting_passes:
        pipeline.append(pass_cls())
        # Every rewriting pass is followed by a canonicalization step.
        pipeline.append(CanonicalizePass())

    return run_passes(exported_program, pipeline)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _warn_training_modules(signatures: list[cutils.Signature]):
    """Log a warning for each signature whose module is still in training mode.

    Args:
      signatures: The signatures about to be converted. The `training` flag of
        each signature's module is inspected.
    """
    for sig in signatures:
        if not sig.module.training:
            continue

        message = (
            "Your model {sig_name}is converted in training mode. "
            "Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
        )
        if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
            # User does not specify any signature names explicitly.
            message = message.format(sig_name="")
        else:
            message = message.format(sig_name=f'"{sig.name}" ')

        # `logging.warn` is a deprecated alias of `logging.warning`.
        logging.warning(message)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def convert_signatures(
    signatures: list[cutils.Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: dict = {},
) -> model.TfLiteModel:
    """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.

    Args:
      signatures: The list of 'Signature' objects containing PyTorch modules to be converted.
      quant_config: User-defined quantization method and scheme of the model.
      _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.

    Returns:
      A `model.TfLiteModel` wrapping the converted tflite model.
    """
    _warn_training_modules(signatures)

    # One exported program per signature, in signature order. Comprehensions are
    # used throughout so that no loop variable keeps a program alive after the
    # explicit `del` statements below.
    exported_programs: list[ExportedProgram] = [
        torch.export.export(
            signature.module,
            signature.sample_args,
            dynamic_shapes=signature.dynamic_shapes,
        )
        for signature in signatures
    ]

    # Apply default fx passes
    exported_programs = [_run_convert_passes(program) for program in exported_programs]

    shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
        cutils.exported_program_to_stablehlo_bundle(program, signature.sample_args)
        for program, signature in zip(exported_programs, signatures)
    ]

    merged_shlo_graph_module: stablehlo.StableHLOGraphModule = cutils.merge_stablehlo_bundles(
        shlo_bundles, signatures, exported_programs
    )

    # Drop the intermediates before the tflite conversion step starts.
    del exported_programs
    del shlo_bundles
    gc.collect()

    tflite_model = cutils.convert_stablehlo_to_tflite(
        merged_shlo_graph_module,
        signatures,
        quant_config=quant_config,
        _tfl_converter_flags=_tfl_converter_flags,
    )
    return model.TfLiteModel(tflite_model)
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import copy
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
import gc
|
|
19
|
+
import itertools
|
|
20
|
+
import logging
|
|
21
|
+
import tempfile
|
|
22
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
from torch_xla import stablehlo
|
|
26
|
+
|
|
27
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
28
|
+
|
|
29
|
+
try:
|
|
30
|
+
import tensorflow as tf
|
|
31
|
+
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
|
32
|
+
|
|
33
|
+
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
|
|
34
|
+
except ImportError:
|
|
35
|
+
logging.error(
|
|
36
|
+
"This module needs tensorflow with xla support.\n"
|
|
37
|
+
"Please install tensorflow with `pip install tf-nightly`.\n"
|
|
38
|
+
)
|
|
39
|
+
raise
|
|
40
|
+
|
|
41
|
+
# Name given to the signature when the caller does not name one explicitly;
# reuses TensorFlow's default serving signature key.
DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class Signature:
    """A named PyTorch module together with the inputs used to export it."""

    # Signature name; used as the prefix of merged function / state names and
    # as the key of the saved-model signature.
    name: str
    # The PyTorch module to be converted.
    module: torch.nn.Module
    # Sample positional arguments used to trace/export the module.
    sample_args: tuple[torch.Tensor]
    # Optional dynamic shape specification forwarded to the exporter.
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def exported_program_to_stablehlo_bundle(
    exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
) -> stablehlo.StableHLOModelBundle:
    """Lower an exported program to StableHLO and return its underlying bundle.

    Args:
      exported_program: The program to lower.
      sample_args: Arguments passed to the exporter as `override_tracing_arguments`.

    Returns:
      The `_bundle` of the StableHLO graph module produced by torch_xla.
    """
    # export_weights is False so that pytorch/xla avoids copying the weights to
    # a numpy array, which would lead to memory bloat. As a consequence the
    # state_dict in the returned bundle is going to be empty.
    export_options = stablehlo.StableHLOExportOptions(
        override_tracing_arguments=sample_args, export_weights=False
    )
    shlo_graph_module = stablehlo.exported_program_to_stablehlo(
        exported_program, export_options
    )
    return shlo_graph_module._bundle
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
    """Convert a torch tensor to a TensorFlow tensor.

    The conversion goes through dlpack first; if that fails for any reason it
    falls back to a numpy round trip.

    Args:
      torch_tensor: The tensor to convert. Non-contiguous tensors are made
        contiguous first.

    Returns:
      The resulting TensorFlow tensor.
    """
    source = torch_tensor if torch_tensor.is_contiguous() else torch_tensor.contiguous()

    try:
        capsule = torch.utils.dlpack.to_dlpack(source)
        return tf.experimental.dlpack.from_dlpack(capsule)
    except Exception:
        # Deliberate best-effort: any dlpack failure triggers the numpy path.
        logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")

    return tf.convert_to_tensor(source.cpu().detach().numpy())
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _get_states(
    exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
):
    """Yield the state tensors (everything except user inputs) of each program.

    Args:
      exported_programs: Exported programs, aligned one-to-one with `signatures`.
      signatures: The signature each exported program was created from.

    Yields:
      `(signature, tensor, input_spec)` triples for every flat graph input that
      is a `torch.Tensor` and whose input spec is not a user input.
    """
    for program, signature in zip(exported_programs, signatures):
        sample_args, _ = program.example_inputs
        # Calling this to get **all** the state including model buffers.
        flat_inputs = program._graph_module_flat_inputs(sample_args, {})
        input_specs = program.graph_signature.input_specs

        for tensor, input_spec in zip(flat_inputs, input_specs):
            # Only interested in Tensors that are part of the state (and not user input).
            is_state_tensor = (
                isinstance(tensor, torch.Tensor)
                and input_spec.kind != torch.export.graph_signature.InputKind.USER_INPUT
            )
            if is_state_tensor:
                yield signature, tensor, input_spec
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _tensor_unique_id(tensor: torch.Tensor):
|
|
101
|
+
return (
|
|
102
|
+
str(tensor.device),
|
|
103
|
+
tensor.shape,
|
|
104
|
+
tensor.stride(),
|
|
105
|
+
tensor.untyped_storage().data_ptr(),
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _gather_state_dict(
    exported_programs: list[torch.export.ExportedProgram],
    signatures: list[Signature],
):
    """Collect the state tensors of all programs as TensorFlow tensors.

    Tensors with the same unique id are converted only once and the resulting
    TensorFlow tensor is shared between all state entries that refer to it.

    Args:
      exported_programs: Exported programs, aligned one-to-one with `signatures`.
      signatures: The signature each exported program was created from.

    Returns:
      A dict mapping `"<signature name>_<input spec target>"` to the converted
      TensorFlow tensor.
    """
    deduped_tensor_map = {}
    state_dict = {}

    for signature, tensor, input_spec in _get_states(exported_programs, signatures):
        unique_id = _tensor_unique_id(tensor)
        # Convert on first sight only; re-converting a duplicate would waste
        # time and memory and defeat the purpose of the dedup map.
        if unique_id not in deduped_tensor_map:
            deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
        state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[unique_id]

    return state_dict
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def merge_stablehlo_bundles(
    bundles: list[stablehlo.StableHLOModelBundle],
    signatures: list[Signature],
    exported_programs: list[torch.export.ExportedProgram],
) -> stablehlo.StableHLOGraphModule:
    """Merge per-signature StableHLO bundles into a single graph module.

    The functions of the input bundles are modified in place: their names and
    parameter names get prefixed with the signature name, and constant
    positions are shifted to index into the merged constant list.

    Args:
      bundles: One bundle per signature, in the same order as `signatures`.
      signatures: The signatures the bundles were created from.
      exported_programs: The exported programs, used to gather the state dict.

    Returns:
      A `StableHLOGraphModule` wrapping the merged bundle.
    """
    merged = stablehlo.StableHLOModelBundle(
        state_dict=_gather_state_dict(exported_programs, signatures),
        additional_constants=[],
        stablehlo_funcs=[],
    )

    for bundle, signature in zip(bundles, signatures):
        prefix = signature.name + "_"
        # Constants of this bundle are appended after the ones merged so far.
        const_offset = len(merged.additional_constants)

        for func in bundle.stablehlo_funcs:
            func.meta.name = prefix + func.meta.name
            for loc in func.meta.input_locations:
                location_type = loc.type_
                if location_type == stablehlo.VariableType.CONSTANT:
                    loc.position += const_offset
                elif location_type == stablehlo.VariableType.PARAMETER:
                    loc.name = prefix + loc.name
            merged.stablehlo_funcs.append(func)

        merged.additional_constants.extend(bundle.additional_constants)

    return stablehlo.StableHLOGraphModule(merged)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
    """Return a copy of the signature's shape with dynamic dimensions set to None.

    Args:
      signature: Variable signature providing `shape` and `dynamic_dims`.

    Returns:
      A shallow copy of `signature.shape` where every index listed in
      `signature.dynamic_dims` holds `None`.
    """
    dims = copy.copy(signature.shape)
    for dynamic_axis in signature.dynamic_dims:
        dims[dynamic_axis] = None
    return dims
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _wrap_as_tf_func(
    func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
):
    """Wrap a StableHLO function into a Python callable built on `call_module`.

    Args:
      func: The StableHLO function whose bytecode should be invoked.
      bundle: The bundle used to resolve the call parameters of `func`.

    Returns:
      A callable taking the user inputs positionally and returning the result
      of `tfxla.call_module`.
    """

    def inner(*args):
        outputs = func.meta.output_signature
        output_dtypes = [output.dtype for output in outputs]
        output_shapes = [_get_shape_with_dynamic(output) for output in outputs]
        call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
        return tfxla.call_module(
            tuple(call_args),
            version=5,
            Tout=output_dtypes,
            Sout=output_shapes,
            function_list=[],
            module=func.bytecode,
        )

    return inner
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _make_tf_function(
    shlo_graph_module: stablehlo.StableHLOGraphModule,
    bundle: stablehlo.StableHLOModelBundle = None,
):
    """Create one wrapped callable per StableHLO function of the graph module.

    Args:
      shlo_graph_module: Graph module whose functions are wrapped.
      bundle: Bundle handed to each wrapper; defaults to the graph module's own
        bundle when None.

    Returns:
      A list of callables, in the order of the graph module's functions.
    """
    if bundle is None:
        bundle = shlo_graph_module._bundle

    wrapped_funcs = []
    # The functions always come from the graph module's own bundle, even when
    # a separate `bundle` is supplied.
    for func in shlo_graph_module._bundle.stablehlo_funcs:
        wrapped_funcs.append(_wrap_as_tf_func(func, bundle))
    return wrapped_funcs
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _make_tf_signature(
    meta: stablehlo.StableHLOFunctionMeta,
) -> list[tf.TensorSpec]:
    """Build the TensorFlow input signature of a StableHLO function.

    Args:
      meta: Metadata of the function, providing used and unused input
        locations together with their variable signatures.

    Returns:
      One `tf.TensorSpec` per user input argument, ordered by input position
      and named `args_<position>`.
    """
    located_specs = itertools.chain(
        zip(meta.input_locations, meta.input_signature), meta.unused_inputs
    )
    # Keep only user-provided arguments, keyed by their position.
    input_pos_to_spec = {}
    for loc, spec in located_specs:
        if loc.type_ == stablehlo.VariableType.INPUT_ARG:
            input_pos_to_spec[loc.position] = spec

    primitive_type_to_tf_type = {"int": "int32", "float": "float32"}

    ret: list[tf.TensorSpec] = []
    for position in range(len(input_pos_to_spec)):
        spec = input_pos_to_spec[position]
        shape = _get_shape_with_dynamic(spec)
        # Primitive type names are translated; any other dtype is passed through.
        dtype = primitive_type_to_tf_type.get(spec.dtype, spec.dtype)
        ret.append(tf.TensorSpec(shape=shape, dtype=dtype, name=f"args_{position}"))
    return ret
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def _apply_tfl_backdoor_flags(
    converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict
):
    """Set attributes on the converter from a nested dictionary of flags.

    Every leaf of the dictionary is assigned with `setattr`; the keys leading to
    the leaf form the attribute path starting at `converter`. For example
    `{"a": {"b": 1}}` results in `converter.a.b = 1`.

    Args:
      converter: The tflite converter to configure.
      tfl_converter_flags: Nested dictionary of attribute names to values.
    """

    def _assign(attr_path: tuple, value):
        # Resolve the owner from the converter for each leaf, then set the flag.
        owner = converter
        for attr in attr_path[:-1]:
            owner = getattr(owner, attr)
        setattr(owner, attr_path[-1], value)

    def _walk(subtree: dict, prefix: tuple):
        for key, value in subtree.items():
            attr_path = prefix + (key,)
            if isinstance(value, dict):
                _walk(value, attr_path)
            else:
                _assign(attr_path, value)

    _walk(tfl_converter_flags, ())
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _set_tfl_converter_quant_flags(
    converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
):
    """Configure the converter according to the quantizer mode of `quant_config`.

    Args:
      converter: The tflite converter to configure.
      quant_config: Quantization configuration; nothing is changed when None or
        when its quantizer mode is not one of the handled modes.
    """
    if quant_config is None:
        return

    modes = qcfg.QuantConfig._QuantizerMode
    selected_mode = quant_config._quantizer_mode

    if selected_mode == modes.PT2E_DYNAMIC:
        converter._experimental_qdq_conversion_mode = "DYNAMIC"
    elif selected_mode == modes.PT2E_STATIC:
        converter._experimental_qdq_conversion_mode = "STATIC"
    elif selected_mode == modes.TFLITE_DYNAMIC:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    elif selected_mode == modes.TFLITE_FP16:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
def convert_stablehlo_to_tflite(
    shlo_graph_module: stablehlo.StableHLOGraphModule,
    signatures: list[Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: dict = {},
) -> bytes:
    """Converts a StableHLOGraphModule to a tflite model.

    The bundle of `shlo_graph_module` is modified in place: its state dict and
    additional constants are replaced by non-trainable `tf.Variable`s.

    Args:
      shlo_graph_module: Model to export and convert.
      signatures: List of signatures from which names of the signatures is extracted.
      quant_config: User-defined quantization method and scheme of the model.
      _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.

    Returns:
      The result of `tf.lite.TFLiteConverter.convert()`, i.e. the converted
      tflite model in serialized form.
    """

    bundle = shlo_graph_module._bundle
    tf_module = tf.Module()
    bundle.state_dict = {
        k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
    }
    bundle.additional_constants = [
        tf.Variable(v, trainable=False) for v in bundle.additional_constants
    ]
    tf_signatures: list[list[tf.TensorSpec]] = [
        _make_tf_signature(func.meta) for func in bundle.stablehlo_funcs
    ]

    tf_functions = _make_tf_function(shlo_graph_module, bundle)

    tf_module.f = []
    for tf_sig, func in zip(tf_signatures, tf_functions):
        tf_module.f.append(
            tf.function(
                func,
                input_signature=tf_sig,
            )
        )

    # Attach the variables to the module so they are saved together with it.
    tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
    del bundle
    gc.collect()

    tf_concrete_funcs = [
        func.get_concrete_function(*tf_sig)
        for func, tf_sig in zip(tf_module.f, tf_signatures)
    ]

    # We need to temporarily save since TFLite's from_concrete_functions does not
    # allow providing names for each of the concrete functions.
    with tempfile.TemporaryDirectory() as temp_dir_path:
        tf.saved_model.save(
            tf_module,
            temp_dir_path,
            signatures={
                sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
            },
        )
        # Clean up intermediate memory early.
        del tf_module
        del tf_concrete_funcs
        gc.collect()

        converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
        converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
        converter._experimental_enable_composite_direct_lowering = True

        # Quantization flags first, so that user-provided backdoor flags can
        # still override them.
        _set_tfl_converter_quant_flags(converter, quant_config)
        _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)

        tflite_model = converter.convert()

    return tflite_model
|
|
@@ -0,0 +1,171 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
19
|
+
|
|
20
|
+
import torch
|
|
21
|
+
|
|
22
|
+
from ai_edge_torch import model
|
|
23
|
+
from ai_edge_torch.convert import conversion
|
|
24
|
+
from ai_edge_torch.convert import conversion_utils as cutils
|
|
25
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Converter:
    """Collects named signatures and converts them into a single edge model."""

    def __init__(self):
        # Signatures added so far, in insertion order.
        self._signatures: list[cutils.Signature] = []

    def signature(
        self,
        name: str,
        module: torch.nn.Module,
        sample_args: Tuple[Any, ...],
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ) -> Converter:
        """Alias to `add_signature`"""
        return self.add_signature(name, module, sample_args, dynamic_shapes)

    def add_signature(
        self,
        name: str,
        module: torch.nn.Module,
        sample_args: Tuple[Any, ...],
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    ) -> Converter:
        """Allows adding a new named torch model along with sample args to the conversion.

        Args:
          name: The name of the signature included in the converted edge model.
          module: The torch module to be converted.
          sample_args: Tuple of args by which the torch module will be traced prior to conversion.
          dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
            See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.

        Returns:
          This converter, so that calls can be chained.

        Raises:
          ValueError: If a signature with the provided name already exists.
        """

        if any(sig.name == name for sig in self._signatures):
            raise ValueError(f"A signature with the provided name ({name}) is already added.")

        self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes))
        return self

    def convert(
        self,
        module: torch.nn.Module = None,
        sample_args: Tuple[Any, ...] = None,
        *,
        quant_config: Optional[qcfg.QuantConfig] = None,
        dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
        _ai_edge_converter_flags: dict = {},
    ) -> model.TfLiteModel:
        """Finalizes the conversion and produces an edge model.

        This could be called with no arguments as follows:

          edge_model = Converter().signature(name, module, args).convert()

        Or it could be used to set the default signature for the converted edge model:

          edge_model = Converter().convert(module, args)

        Args:
          module: The torch module to be converted.
          sample_args: Tuple of args by which the torch module will be traced prior to conversion.
          quant_config: User-defined quantization method and scheme of the model.
          dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
            See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
          _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
            This gives access to an implementation detail of this function and so needs to be treated as such.
            Please do not rely on this parameter except for local debugging as this can be removed in a future release.

        Returns:
          The converted edge model.

        Raises:
          ValueError: If `module` is provided without `sample_args`, or if the
            default signature has already been added to this converter.
        """
        if module is not None:
            if sample_args is None:  # module is provided but not sample_args
                raise ValueError("sample_args needs to be provided if a module is specified.")
            # Both module and args provided: register them as the default signature.
            self.add_signature(
                cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes
            )

        return conversion.convert_signatures(
            self._signatures,
            quant_config=quant_config,
            _tfl_converter_flags=_ai_edge_converter_flags,
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def signature(
    name: str,
    module: torch.nn.Module,
    sample_args: Tuple[Any, ...],
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
) -> Converter:
    """Initiates a Converter object with the provided signature.

    Args:
      name: The name of the signature included in the converted edge model.
      module: The torch module to be converted.
      sample_args: Tuple of args by which the torch module will be traced prior to conversion.
      dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
        See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.

    Returns:
      A new `Converter` holding the provided signature.

    Example:
      converter = ai_edge_torch.signature(name, module, args)
      edge_model = converter.convert()

    """
    return Converter().signature(name, module, sample_args, dynamic_shapes)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def convert(
    module: torch.nn.Module = None,
    sample_args: Tuple[Any, ...] = None,
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
    _ai_edge_converter_flags: dict = {},
) -> model.TfLiteModel:
    """Allows converting a PyTorch model to an edge model with one default signature in one step.

    Args:
      module: The torch module to be converted.
      sample_args: Tuple of args by which the torch module will be traced prior to conversion.
      quant_config: User-defined quantization method and scheme of the model.
      dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
        See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
      _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
        This gives access to an implementation detail of this function and so needs to be treated as such.
        Please do not rely on this parameter except for local debugging as this can be removed in a future release.

    Returns:
      The converted edge model.

    Raises:
      ValueError: If `module` is provided without `sample_args`.

    Example:
      edge_model = ai_edge_torch.convert(module, args)

    """

    return Converter().convert(
        module,
        sample_args,
        quant_config=quant_config,
        dynamic_shapes=dynamic_shapes,
        _ai_edge_converter_flags=_ai_edge_converter_flags,
    )
|