ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
- ai_edge_torch/convert/conversion_utils.py +0 -439
- ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
ai_edge_torch/lowertools/common_utils.py
ADDED

@@ -0,0 +1,89 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import logging
+
+from ai_edge_torch._convert import signature as signature_module
+import tensorflow as tf
+import torch
+
+
+def _torch_to_tf_variable(torch_tensor: torch.Tensor):
+  if not torch_tensor.is_contiguous():
+    torch_tensor = torch_tensor.contiguous()
+
+  try:
+    dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
+    tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
+  except Exception:
+    logging.info(
+        "Can not use dlpack to convert torch tensors. Falling back to numpy."
+    )
+    nparray = torch_tensor.cpu().detach().numpy()
+    tf_tensor = tf.convert_to_tensor(nparray)
+
+  return tf.Variable(tf_tensor, trainable=False)
+
+
+def _get_states(
+    exported_programs: list[torch.export.ExportedProgram],
+    signatures: list[signature_module.Signature],
+):
+  for exported_program, signature in zip(exported_programs, signatures):
+    args, _ = exported_program.example_inputs
+    # Calling this to get **all** the state including model buffers.
+    _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
+    for tensor, input_spec in zip(
+        _flat_input_args, exported_program.graph_signature.input_specs
+    ):
+      # Only interested in Tensors that are part of the state (and not user input).
+      if (
+          not isinstance(tensor, torch.Tensor)
+          or input_spec.kind
+          == torch.export.graph_signature.InputKind.USER_INPUT
+      ):
+        continue
+      yield signature, tensor, input_spec
+
+
+def _tensor_unique_id(tensor: torch.Tensor):
+  return (
+      str(tensor.device),
+      tensor.shape,
+      tensor.stride(),
+      tensor.untyped_storage().data_ptr(),
+  )
+
+
+def gather_state_dict(
+    exported_programs: list[torch.export.ExportedProgram],
+    signatures: list[signature_module.Signature],
+):
+  deduped_tensor_map = {}
+
+  for _, tensor, _ in _get_states(exported_programs, signatures):
+    unique_id = _tensor_unique_id(tensor)
+    deduped_tensor_map[unique_id] = _torch_to_tf_variable(tensor)
+
+  state_dict = {}
+  for signature, tensor, input_spec in _get_states(
+      exported_programs, signatures
+  ):
+    unique_id = _tensor_unique_id(tensor)
+    state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
+        unique_id
+    ]
+
+  return state_dict, list(deduped_tensor_map.values())
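The dedup above keys every non-user-input tensor by device, shape, strides, and storage address, so a weight shared across signatures becomes a single tf.Variable. A minimal, self-contained sketch of that key (illustrative only, not package code):

import torch


def tensor_unique_id(t: torch.Tensor):
  # Same fields as _tensor_unique_id above: tensors that share storage and
  # layout collapse to a single entry.
  return (str(t.device), t.shape, t.stride(), t.untyped_storage().data_ptr())


w = torch.randn(4, 4)
alias = w          # Same storage and metadata -> same key as w.
clone = w.clone()  # Fresh storage -> different key.

deduped = {tensor_unique_id(t): t for t in (w, alias, clone)}
print(len(deduped))  # 2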
ai_edge_torch/lowertools/odml_torch_utils.py
ADDED

@@ -0,0 +1,211 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import dataclasses
+import tempfile
+from typing import Any, Dict, List, Optional, Tuple
+
+from ai_edge_torch import odml_torch
+from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch._convert import signature as signature_module
+from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.odml_torch import export
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.quantize import quant_config as qcfg
+import tensorflow as tf
+import torch
+
+from tensorflow.compiler.tf2xla.python import xla as tfxla
+
+MlirBundle = odml_torch.export.MlirLowered
+
+
+@dataclasses.dataclass
+class MergedBundle:
+  """A bundle of MlirLowered that has been merged."""
+
+  bundles: list[odml_torch.export.MlirLowered]
+  deduped_tf_vars: list[tf.Variable]
+
+
+def torch_dtype_to_tf(dtype):
+  return {
+      torch.double: tf.float64,
+      torch.float32: tf.float32,
+      torch.half: tf.float16,
+      torch.long: tf.int64,
+      torch.int32: tf.int32,
+      torch.int16: tf.int16,
+      torch.bool: tf.bool,
+  }.get(dtype)
+
+
+def _get_shape_with_dynamic(signature: export.VariableSignature):
+  return [
+      None if export_utils.is_torch_dynamic(s) else s for s in signature.shape
+  ]
+
+
+def _extract_call_args(
+    bundle: export.MlirLowered,
+    args: Tuple[Any],
+    tf_state_dict: Dict[str, tf.Variable],
+):
+  call_args = []
+  for sig in bundle.input_signature:
+    if sig.input_spec.is_user_input:
+      call_args.append(args[sig.input_spec.i])
+    elif sig.input_spec.is_parameter:
+      name = sig.input_spec.name
+      call_args.append(tf_state_dict[name])
+  return call_args
+
+
+def _wrap_as_tf_func(bundle, tf_state_dict):
+  def inner(*args):
+    t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
+    s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
+    call_args = _extract_call_args(bundle, args, tf_state_dict)
+    return tfxla.call_module(
+        tuple(call_args),
+        version=5,
+        Tout=t_outs,  # dtype information
+        Sout=s_outs,  # Shape information
+        function_list=[],
+        module=bundle.module_bytecode,
+    )
+
+  return inner
+
+
+def _make_tf_signature(
+    input_signature: list[export.VariableSignature],
+    signature: signature_module.Signature,
+) -> List[tf.TensorSpec]:
+  input_names = signature.flat_arg_names
+  user_input_signature = sorted(
+      [sig for sig in input_signature if sig.input_spec.is_user_input],
+      key=lambda sig: sig.input_spec.i,
+  )
+  tf_signature = []
+
+  for sig in user_input_signature:
+    shape = _get_shape_with_dynamic(sig)
+    tf_signature.append(
+        tf.TensorSpec(
+            shape=shape,
+            dtype=torch_dtype_to_tf(sig.dtype),
+            name=input_names[sig.input_spec.i],
+        )
+    )
+  return tf_signature
+
+
+def merged_bundle_to_tfl_model(
+    merged_bundle: MergedBundle,
+    signatures: list[signature_module.Signature],
+    *,
+    quant_config: Optional[qcfg.QuantConfig] = None,
+    _tfl_converter_flags: dict = {},
+):
+  tf_state_dict = merged_bundle.bundles[0].state_dict
+
+  tf_signatures = [
+      _make_tf_signature(bundle.input_signature, sig)
+      for bundle, sig in zip(merged_bundle.bundles, signatures)
+  ]
+  tf_functions = [
+      _wrap_as_tf_func(bundle, tf_state_dict)
+      for bundle in merged_bundle.bundles
+  ]
+
+  tf_module = tf.Module()
+  tf_module.f = []
+
+  for tf_sig, func in zip(tf_signatures, tf_functions):
+    tf_module.f.append(
+        tf.function(
+            func,
+            input_signature=tf_sig,
+        )
+    )
+
+  tf_module._variables = merged_bundle.deduped_tf_vars
+
+  tf_concrete_funcs = [
+      func.get_concrete_function(*tf_sig)
+      for func, tf_sig in zip(tf_module.f, tf_signatures)
+  ]
+
+  # We need to temporarily save since TFLite's from_concrete_functions does not
+  # allow providing names for each of the concrete functions.
+  with tempfile.TemporaryDirectory() as temp_dir_path:
+    tf.saved_model.save(
+        tf_module,
+        temp_dir_path,
+        signatures={
+            sig.name: tf_concrete_funcs[idx]
+            for idx, sig in enumerate(signatures)
+        },
+    )
+
+    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
+    converter._experimental_enable_composite_direct_lowering = True
+
+    conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
+
+    tflite_model = converter.convert()
+
+  return tflite_model
+
+
+def exported_program_to_mlir_text(
+    exported_program: torch.export.ExportedProgram,
+) -> str:
+  """Converts a ExportedProgram to a MLIR text."""
+  return odml_torch.export.exported_program_to_mlir(exported_program).get_text(
+      enable_debug_info=True
+  )
+
+
+def exported_program_to_mlir(
+    exported_program: torch.export.ExportedProgram,
+    sample_args: tuple[torch.Tensor],
+) -> export.MlirLowered:
+  """Converts a ExportedProgram to a MlirLowered."""
+  return odml_torch.export.exported_program_to_mlir(exported_program)
+
+
+def merge_mlir_bundles(
+    bundles: list[export.MlirLowered],
+    signatures: list[signature_module.Signature],
+    exported_programs: list[torch.export.ExportedProgram],
+) -> MergedBundle:
+  """Merges a list of MlirLowered into one."""
+  state_dict, deduped_vars = common_utils.gather_state_dict(
+      exported_programs, signatures
+  )
+
+  merged_bundle = MergedBundle(
+      bundles=bundles.copy(), deduped_tf_vars=deduped_vars
+  )
+  for bundle, signature in zip(merged_bundle.bundles, signatures):
+    bundle.state_dict = state_dict
+
+    for var_sig in bundle.input_signature:
+      if var_sig.input_spec.is_parameter:
+        var_sig.input_spec.name = signature.name + "_" + var_sig.input_spec.name
+
+  return merged_bundle
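Both lowering backends end the same way: each lowered function is wrapped as a tf.function with an explicit input signature, the tf.Module is saved to a temporary SavedModel so every signature keeps its name, and the TFLite converter then reads that directory. A stripped-down, hypothetical illustration of that save-then-convert flow (the Adder module and the "add_one" signature name are invented for the example):

import tempfile

import tensorflow as tf


class Adder(tf.Module):

  @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
  def add_one(self, x):
    return x + 1.0


module = Adder()
with tempfile.TemporaryDirectory() as temp_dir_path:
  # Saving with named signatures mirrors merged_bundle_to_tfl_model above:
  # from_concrete_functions cannot name signatures, but a SavedModel can.
  tf.saved_model.save(
      module, temp_dir_path, signatures={"add_one": module.add_one}
  )
  converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
  tflite_model = converter.convert()

print(type(tflite_model))  # bytes: the serialized .tflite flatbuffer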
ai_edge_torch/lowertools/torch_xla_utils.py
ADDED

@@ -0,0 +1,273 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import copy
+import dataclasses
+from dataclasses import dataclass
+import gc
+import itertools
+import logging
+import tempfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+from ai_edge_torch import model
+from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch._convert import signature as signature_module
+from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
+from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.quantize import quant_config as qcfg
+import torch
+from torch_xla import stablehlo
+
+try:
+  import tensorflow as tf
+
+  from tensorflow.compiler.tf2xla.python import xla as tfxla
+
+  from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb  # isort:skip
+except ImportError:
+  logging.error(
+      "This module needs tensorflow with xla support.\n"
+      "Please install tensorflow with `pip install tf-nightly`.\n"
+  )
+  raise
+
+MlirBundle = stablehlo.StableHLOModelBundle
+
+
+@dataclasses.dataclass
+class MergedBundle:
+
+  bundle: stablehlo.StableHLOModelBundle
+  deduped_tf_vars: list[tf.Variable]
+
+
+def exported_program_to_mlir(
+    exported_program: torch.export.ExportedProgram,
+    sample_args: tuple[torch.Tensor],
+) -> stablehlo.StableHLOModelBundle:
+  # Setting export_weights to False here so that pytorch/xla avoids copying the weights
+  # to a numpy array which would lead to memory bloat. This means that the state_dict
+  # in the returned bundle is going to be empty.
+  return stablehlo.exported_program_to_stablehlo(
+      exported_program,
+      stablehlo.StableHLOExportOptions(
+          override_tracing_arguments=sample_args, export_weights=False
+      ),
+  )._bundle
+
+
+def merge_mlir_bundles(
+    bundles: list[stablehlo.StableHLOModelBundle],
+    signatures: list[signature_module.Signature],
+    exported_programs: list[torch.export.ExportedProgram],
+) -> stablehlo.StableHLOGraphModule:
+  state_dict, deduped_tf_vars = common_utils.gather_state_dict(
+      exported_programs, signatures
+  )
+
+  new_shlo_model_bundle = stablehlo.StableHLOModelBundle(
+      state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
+  )
+
+  for bundle, signature in zip(bundles, signatures):
+    const_offset = len(new_shlo_model_bundle.additional_constants)
+    for func in bundle.stablehlo_funcs:
+      func.meta.name = signature.name + "_" + func.meta.name
+      for loc in func.meta.input_locations:
+        if loc.type_ == stablehlo.VariableType.CONSTANT:
+          loc.position += const_offset
+        elif loc.type_ == stablehlo.VariableType.PARAMETER:
+          loc.name = signature.name + "_" + loc.name
+      new_shlo_model_bundle.stablehlo_funcs.append(func)
+    new_shlo_model_bundle.additional_constants.extend(
+        bundle.additional_constants
+    )
+  return MergedBundle(
+      bundle=new_shlo_model_bundle, deduped_tf_vars=deduped_tf_vars
+  )
+
+
+def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
+  shape = copy.copy(signature.shape)
+  for i in signature.dynamic_dims:
+    shape[i] = None
+  return shape
+
+
+def _wrap_as_tf_func(
+    func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
+):
+  def inner(*args):
+    type_info = [sig.dtype for sig in func.meta.output_signature]
+    shape_info = [
+        _get_shape_with_dynamic(sig) for sig in func.meta.output_signature
+    ]
+    call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
+    return tfxla.call_module(
+        tuple(call_args),
+        version=5,
+        Tout=type_info,
+        Sout=shape_info,
+        function_list=[],
+        module=func.bytecode,
+    )

+  return inner
+
+
+def _make_tf_function(
+    bundle: stablehlo.StableHLOModelBundle = None,
+):
+  bundle = bundle if bundle is None else bundle
+  return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]
+
+
+def _make_tf_signature(
+    meta: stablehlo.StableHLOFunctionMeta,
+    signature: signature_module.Signature,
+) -> list[tf.TensorSpec]:
+  input_names = signature.flat_arg_names
+  input_pos_to_spec = {
+      loc.position: spec
+      for loc, spec in itertools.chain(
+          zip(meta.input_locations, meta.input_signature), meta.unused_inputs
+      )
+      if loc.type_ == stablehlo.VariableType.INPUT_ARG
+  }
+  assert len(input_pos_to_spec) == len(input_names)
+
+  primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
+  ret: list[tf.TensorSpec] = []
+  for i, name in enumerate(input_names):
+    spec = input_pos_to_spec[i]
+    shape = _get_shape_with_dynamic(spec)
+    ret.append(
+        tf.TensorSpec(
+            shape=shape,
+            dtype=primitive_type_to_tf_type[spec.dtype]
+            if spec.dtype in primitive_type_to_tf_type
+            else spec.dtype,
+            name=name,
+        )
+    )
+  return ret
+
+
+def exported_program_to_mlir_text(
+    exported_program: torch.export.ExportedProgram,
+) -> str:
+  """Converts a ExportedProgram to a MLIR text."""
+  return stablehlo.exported_program_to_stablehlo(
+      exported_program
+  ).get_stablehlo_text()
+
+
+def merged_bundle_to_tfl_model(
+    merged_bundle: MergedBundle,
+    signatures: list[signature_module.Signature],
+    *,
+    quant_config: Optional[qcfg.QuantConfig] = None,
+    _tfl_converter_flags: dict = {},
+) -> None:
+  """Converts a StableHLOGraphModule to a tflite model.
+
+  Args: shlo_bundle - model to export and save
+
+    signatures: List of signatures from which names of the signatures is
+      extracted.
+    quant_config: User-defined quantization method and scheme of the model.
+    _tfl_converter_flags: A nested dictionary allowing setting flags for the
+      underlying tflite converter.
+  """
+
+  tf_module = tf.Module()
+
+  shlo_bundle = merged_bundle.bundle
+
+  shlo_bundle.additional_constants = [
+      tf.Variable(v, trainable=False) for v in shlo_bundle.additional_constants
+  ]
+  tf_signatures: list[list[tf.TensorSpec]] = list(
+      _make_tf_signature(func.meta, sig)
+      for func, sig in zip(shlo_bundle.stablehlo_funcs, signatures)
+  )
+
+  tf_functions = _make_tf_function(shlo_bundle)
+
+  tf_module.f = []
+  for tf_sig, func in zip(tf_signatures, tf_functions):
+    tf_module.f.append(
+        tf.function(
+            func,
+            input_signature=tf_sig,
+        )
+    )
+
+  tf_module._variables = (
+      merged_bundle.deduped_tf_vars + shlo_bundle.additional_constants
+  )
+  del shlo_bundle
+  gc.collect()
+
+  tf_concrete_funcs = [
+      func.get_concrete_function(*tf_sig)
+      for func, tf_sig in zip(tf_module.f, tf_signatures)
+  ]
+
+  # We need to temporarily save since TFLite's from_concrete_functions does not
+  # allow providing names for each of the concrete functions.
+  with tempfile.TemporaryDirectory() as temp_dir_path:
+    tf.saved_model.save(
+        tf_module,
+        temp_dir_path,
+        signatures={
+            sig.name: tf_concrete_funcs[idx]
+            for idx, sig in enumerate(signatures)
+        },
+    )
+    # Clean up intermediate memory early.
+    del tf_module
+    del tf_concrete_funcs
+    gc.collect()
+
+    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
+    converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
+    converter._experimental_enable_composite_direct_lowering = True
+
+    conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+    if (
+        quant_config is not None
+        and quant_config._quantizer_mode
+        == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+    ):
+      translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+          quant_config.generative_recipe
+      )
+
+    conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
+
+    tflite_model = converter.convert()
+
+    if (
+        quant_config is not None
+        and quant_config._quantizer_mode
+        == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+    ):
+      tflite_model = translate_recipe.quantize_model(
+          tflite_model, translated_recipe
+      )
+
+  return tflite_model
ai_edge_torch/model.py
CHANGED
@@ -15,17 +15,18 @@

 """Represents an ai_edge_torch model.

-PyTorch models can be converted to this representation through
+PyTorch models can be converted to this representation through
+`ai_edge_torch.convert`.
 """
 from __future__ import annotations

 import abc

-from ai_edge_torch.convert import conversion_utils as cutils
-import numpy as np
 import numpy.typing as npt
 import tensorflow as tf

+DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+

 class Model(abc.ABC):
   """Represents and edge model."""

@@ -34,7 +35,7 @@ class Model(abc.ABC):
   def __call__(
       self,
       *args: npt.ArrayLike,
-      signature_name: str =
+      signature_name: str = DEFAULT_SIGNATURE_NAME,
       **kwargs,
   ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
     raise NotImplementedError()

@@ -66,18 +67,22 @@ class TfLiteModel(Model):
   def __call__(
       self,
       *args: npt.ArrayLike,
-      signature_name: str =
+      signature_name: str = DEFAULT_SIGNATURE_NAME,
      **kwargs,
   ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
     """Runs inference on the edge model using the provided arguments.

     Args:
       *args: The arguments to be passed to the model for inference.
-      **kwargs: The arguments with specific names to be passed to the model for
-
-
+      **kwargs: The arguments with specific names to be passed to the model for
+        inference.
+      signature_name: The name of the signature to be used for inference. The
+        default signature is used if not provided.
     """
-    interpreter = tf.lite.Interpreter(
+    interpreter = tf.lite.Interpreter(
+        model_content=self._tflite_model,
+        experimental_default_delegate_latest_features=True,
+    )
     interpreter.allocate_tensors()

     signature_list = interpreter.get_signature_list()
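The model.py change drops the `from ai_edge_torch.convert import conversion_utils as cutils` import, defines DEFAULT_SIGNATURE_NAME locally from tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY and uses it as the `signature_name` default, and creates the interpreter with experimental_default_delegate_latest_features enabled. A hypothetical usage sketch, assuming the public ai_edge_torch.convert API; the one-layer model and inputs are illustrative:

import ai_edge_torch
import torch

model = torch.nn.Linear(4, 2).eval()
sample_args = (torch.randn(1, 4),)

edge_model = ai_edge_torch.convert(model, sample_args)

# Omitting signature_name now falls back to DEFAULT_SIGNATURE_NAME
# ("serving_default") inside TfLiteModel.__call__, so for a model converted
# with the default signature name these two calls should hit the same signature.
out_default = edge_model(*sample_args)
out_named = edge_model(*sample_args, signature_name="serving_default")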
ai_edge_torch/quantize/pt2e_quantizer.py
CHANGED

@@ -188,15 +188,18 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]:

 def _get_module_name_filter(module_name: str):
   """Get the module_name_filter function for a given module name, the filter accepts
+
   a node and checks if the node comes from a module that has certain module name

   For example:
-    node: linear_op = call_function[...](...) # comes from a module with name
+    node: linear_op = call_function[...](...) # comes from a module with name
+      blocks.sub.linear1


   >> module_name_filter = _get_module_name_filter("blocks.sub")
   >> print(module_name_filter(node))
-  True # the node is from "blocks.sub" based on the fully qualified name
+  True # the node is from "blocks.sub" based on the fully qualified name
+    "blocks.sub.linear1"
   """

   def module_name_filter(n: Node) -> bool:

@@ -216,15 +219,19 @@ def _get_module_name_filter(module_name: str):

 def _get_module_type_filter(tp: Callable):
   """Get the module_type_filter function for a given module type, the filter accepts
+
   a node and checks if the node comes from a module that has certain module type

   For example:
-    node: linear_op = call_function[...](...) # comes from a module with type
+    node: linear_op = call_function[...](...) # comes from a module with type
+      Block -> Sub -> Linear


-  >> module_type_filter = _get_module_type_filter(Sub) # submodule with type
+  >> module_type_filter = _get_module_type_filter(Sub) # submodule with type
+    `Sub`, under the `Block` submodule
   >> print(module_type_filter(node))
-  True # the node is from the submodule `Sub` (same for `Block` and `Linear` as
+  True # the node is from the submodule `Sub` (same for `Block` and `Linear` as
+    well)
   """

   def module_type_filter(n: Node) -> bool:

@@ -338,8 +345,11 @@ class PT2EQuantizer(Quantizer):
       self, module_type: Callable, quantization_config: QuantizationConfig
   ):
     """Set quantization_config for a submodule with type: `module_type`, for example:
-
-
+
+    quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it
+    will quantize all supported operator/operator
+    patterns in the submodule with this module type with the given
+    `quantization_config`
     """
     self.module_type_config[module_type] = quantization_config
     return self

@@ -348,8 +358,11 @@ class PT2EQuantizer(Quantizer):
       self, module_name: str, quantization_config: Optional[QuantizationConfig]
   ):
     """Set quantization_config for a submodule with name: `module_name`, for example:
-
-
+
+    quantizer.set_module_name("blocks.sub"), it will quantize all supported
+    operator/operator
+    patterns in the submodule with this module name with the given
+    `quantization_config`
     """
     assert (
         quantization_config is not None