ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
- ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
- ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0
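Most of this release is an internal reshuffle (convert → _convert, plus the new lowertools package); the documented user-facing entry point, ai_edge_torch.convert, is still present (ai_edge_torch/__init__.py changes by only a few lines). The usage sketch below is based on the project README rather than on this diff; it assumes torchvision is installed, and any eval-mode torch.nn.Module works in place of resnet18.

import ai_edge_torch
import torch
import torchvision

# Any nn.Module in eval mode can be converted; resnet18 is just an example.
resnet18 = torchvision.models.resnet18().eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

edge_model = ai_edge_torch.convert(resnet18, sample_inputs)
result = edge_model(*sample_inputs)   # runs through the TFLite interpreter
edge_model.export("resnet18.tflite")  # writes the flatbuffer to disk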
ai_edge_torch/lowertools/common_utils.py ADDED
@@ -0,0 +1,89 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import logging
+
+from ai_edge_torch._convert import signature as signature_module
+import tensorflow as tf
+import torch
+
+
+def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
+  if not torch_tensor.is_contiguous():
+    torch_tensor = torch_tensor.contiguous()
+
+  try:
+    dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
+    tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
+  except Exception:
+    logging.info(
+        "Can not use dlpack to convert torch tensors. Falling back to numpy."
+    )
+    nparray = torch_tensor.cpu().detach().numpy()
+    tf_tensor = tf.convert_to_tensor(nparray)
+
+  return tf_tensor
+
+
+def _get_states(
+    exported_programs: list[torch.export.ExportedProgram],
+    signatures: list[signature_module.Signature],
+):
+  for exported_program, signature in zip(exported_programs, signatures):
+    args, _ = exported_program.example_inputs
+    # Calling this to get **all** the state including model buffers.
+    _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
+    for tensor, input_spec in zip(
+        _flat_input_args, exported_program.graph_signature.input_specs
+    ):
+      # Only interested in Tensors that are part of the state (and not user input).
+      if (
+          not isinstance(tensor, torch.Tensor)
+          or input_spec.kind
+          == torch.export.graph_signature.InputKind.USER_INPUT
+      ):
+        continue
+      yield signature, tensor, input_spec
+
+
+def _tensor_unique_id(tensor: torch.Tensor):
+  return (
+      str(tensor.device),
+      tensor.shape,
+      tensor.stride(),
+      tensor.untyped_storage().data_ptr(),
+  )
+
+
+def gather_state_dict(
+    exported_programs: list[torch.export.ExportedProgram],
+    signatures: list[signature_module.Signature],
+):
+  deduped_tensor_map = {}
+
+  for _, tensor, _ in _get_states(exported_programs, signatures):
+    unique_id = _tensor_unique_id(tensor)
+    deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
+
+  state_dict = {}
+  for signature, tensor, input_spec in _get_states(
+      exported_programs, signatures
+  ):
+    unique_id = _tensor_unique_id(tensor)
+    state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
+        unique_id
+    ]
+
+  return state_dict
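The new gather_state_dict helper converts each exported program's non-user-input tensors to TF tensors once, keyed by (device, shape, stride, storage pointer), so parameters shared across signatures are not duplicated. The following is a hypothetical usage sketch, not code from the package; it assumes the relocated Signature dataclass in ai_edge_torch/_convert/signature.py keeps the constructor fields of the copy removed later in this diff (name, module, sample_args, sample_kwargs).

import torch
from ai_edge_torch._convert import signature as signature_module
from ai_edge_torch.lowertools import common_utils


class TinyModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(4, 4)

  def forward(self, x):
    return self.linear(x)


m = TinyModel().eval()
args = (torch.randn(1, 4),)

# Two exported programs over the same module share the same weight storage,
# so _torch_to_tf_tensor should only run once per underlying tensor.
exported = [torch.export.export(m, args), torch.export.export(m, args)]
sigs = [
    signature_module.Signature("sig_a", m, args, {}),  # hypothetical field order
    signature_module.Signature("sig_b", m, args, {}),
]
state_dict = common_utils.gather_state_dict(exported, sigs)
# Keys are prefixed with each signature's name ("sig_a_...", "sig_b_...").
print(sorted(state_dict))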
ai_edge_torch/lowertools/odml_torch_utils.py ADDED
@@ -0,0 +1,201 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import tempfile
+from typing import Any, Dict, List, Optional, Tuple
+
+from ai_edge_torch import odml_torch
+from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch._convert import signature as signature_module
+from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.odml_torch import export
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.quantize import quant_config as qcfg
+import tensorflow as tf
+import torch
+
+from tensorflow.compiler.tf2xla.python import xla as tfxla
+
+MlirBundle = odml_torch.export.MlirLowered
+MergedBundle = list[odml_torch.export.MlirLowered]
+
+
+def torch_dtype_to_tf(dtype):
+  return {
+      torch.double: tf.float64,
+      torch.float32: tf.float32,
+      torch.half: tf.float16,
+      torch.long: tf.int64,
+      torch.int32: tf.int32,
+      torch.int16: tf.int16,
+      torch.bool: tf.bool,
+  }.get(dtype)
+
+
+def _get_shape_with_dynamic(signature: export.VariableSignature):
+  return [
+      None if export_utils.is_torch_dynamic(s) else s for s in signature.shape
+  ]
+
+
+def _extract_call_args(
+    bundle: export.MlirLowered,
+    args: Tuple[Any],
+    tf_state_dict: Dict[str, tf.Variable],
+):
+  call_args = []
+  for sig in bundle.input_signature:
+    if sig.input_spec.is_user_input:
+      call_args.append(args[sig.input_spec.i])
+    elif sig.input_spec.is_parameter:
+      name = sig.input_spec.name
+      call_args.append(tf_state_dict[name])
+  return call_args
+
+
+def _wrap_as_tf_func(bundle, tf_state_dict):
+  def inner(*args):
+    t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
+    s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
+    call_args = _extract_call_args(bundle, args, tf_state_dict)
+    return tfxla.call_module(
+        tuple(call_args),
+        version=5,
+        Tout=t_outs,  # dtype information
+        Sout=s_outs,  # Shape information
+        function_list=[],
+        module=bundle.module_bytecode,
+    )
+
+  return inner
+
+
+def _make_tf_signature(
+    input_signature: list[export.VariableSignature],
+    signature: signature_module.Signature,
+) -> List[tf.TensorSpec]:
+  input_names = signature.flat_arg_names
+  user_input_signature = sorted(
+      [sig for sig in input_signature if sig.input_spec.is_user_input],
+      key=lambda sig: sig.input_spec.i,
+  )
+  tf_signature = []
+
+  for sig in user_input_signature:
+    shape = _get_shape_with_dynamic(sig)
+    tf_signature.append(
+        tf.TensorSpec(
+            shape=shape,
+            dtype=torch_dtype_to_tf(sig.dtype),
+            name=input_names[sig.input_spec.i],
+        )
+    )
+  return tf_signature
+
+
+def merged_bundle_to_tfl_model(
+    merged_bundle: MergedBundle,
+    signatures: list[signature_module.Signature],
+    *,
+    quant_config: Optional[qcfg.QuantConfig] = None,
+    _tfl_converter_flags: dict = {},
+):
+  tf_state_dict = {
+      k: tf.Variable(v, trainable=False)
+      for k, v in merged_bundle[0].state_dict.items()
+  }
+
+  tf_signatures = [
+      _make_tf_signature(bundle.input_signature, sig)
+      for bundle, sig in zip(merged_bundle, signatures)
+  ]
+  tf_functions = [
+      _wrap_as_tf_func(bundle, tf_state_dict) for bundle in merged_bundle
+  ]
+
+  tf_module = tf.Module()
+  tf_module.f = []
+
+  for tf_sig, func in zip(tf_signatures, tf_functions):
+    tf_module.f.append(
+        tf.function(
+            func,
+            input_signature=tf_sig,
+        )
+    )
+
+  tf_module._variables = list(tf_state_dict.values())
+
+  tf_concrete_funcs = [
+      func.get_concrete_function(*tf_sig)
+      for func, tf_sig in zip(tf_module.f, tf_signatures)
+  ]
+
+  # We need to temporarily save since TFLite's from_concrete_functions does not
+  # allow providing names for each of the concrete functions.
+  with tempfile.TemporaryDirectory() as temp_dir_path:
+    tf.saved_model.save(
+        tf_module,
+        temp_dir_path,
+        signatures={
+            sig.name: tf_concrete_funcs[idx]
+            for idx, sig in enumerate(signatures)
+        },
+    )
+
+    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
+    converter._experimental_enable_composite_direct_lowering = True
+
+    conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
+
+    tflite_model = converter.convert()
+
+  return tflite_model
+
+
+def exported_program_to_mlir_text(
+    exported_program: torch.export.ExportedProgram,
+) -> str:
+  """Converts a ExportedProgram to a MLIR text."""
+  return odml_torch.export.exported_program_to_mlir(exported_program).get_text(
+      enable_debug_info=True
+  )
+
+
+def exported_program_to_mlir(
+    exported_program: torch.export.ExportedProgram,
+    sample_args: tuple[torch.Tensor],
+) -> export.MlirLowered:
+  """Converts a ExportedProgram to a MlirLowered."""
+  return odml_torch.export.exported_program_to_mlir(exported_program)
+
+
+def merge_mlir_bundles(
+    bundles: list[export.MlirLowered],
+    signatures: list[signature_module.Signature],
+    exported_programs: list[torch.export.ExportedProgram],
+) -> MergedBundle:
+  """Merges a list of MlirLowered into one."""
+  state_dict = common_utils.gather_state_dict(exported_programs, signatures)
+
+  merged_bundle = bundles.copy()
+  for bundle, signature in zip(merged_bundle, signatures):
+    bundle.state_dict = state_dict
+
+    for var_sig in bundle.input_signature:
+      if var_sig.input_spec.is_parameter:
+        var_sig.input_spec.name = signature.name + "_" + var_sig.input_spec.name
+
+  return merged_bundle
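merged_bundle_to_tfl_model above takes a SavedModel detour because, as its comment notes, TFLite's from_concrete_functions does not let callers name each signature. The following is a self-contained sketch of that round trip using plain TensorFlow (not this package's bundle types):

import tempfile

import tensorflow as tf

spec = [tf.TensorSpec([1, 4], tf.float32, name="args_0")]
module = tf.Module()
module.f = [
    tf.function(lambda x: x + 1.0, input_signature=spec),
    tf.function(lambda x: x * 2.0, input_signature=spec),
]
concrete = [f.get_concrete_function() for f in module.f]

with tempfile.TemporaryDirectory() as tmp:
  # Named signatures are only available through the SavedModel path.
  tf.saved_model.save(
      module, tmp, signatures={"add": concrete[0], "mul": concrete[1]}
  )
  converter = tf.lite.TFLiteConverter.from_saved_model(tmp)
  tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())  # signatures named 'add' and 'mul'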
ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} CHANGED
@@ -13,19 +13,21 @@
 # limitations under the License.
 # ==============================================================================
 
-import collections
 import copy
 from dataclasses import dataclass
 import gc
 import itertools
 import logging
 import tempfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
+from ai_edge_torch import model
+from ai_edge_torch._convert import conversion_utils
+from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
+from ai_edge_torch.lowertools import common_utils
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
-import torch.utils._pytree as pytree
 from torch_xla import stablehlo
 
 try:
@@ -41,92 +43,11 @@ except ImportError:
   )
   raise
 
-
-
-
-@dataclass
-class Signature:
-  name: str
-  module: torch.nn.Module
-  sample_args: tuple[torch.Tensor]
-  sample_kwargs: dict[str, torch.Tensor]
-  dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
-
-  @property
-  def _normalized_sample_args_kwargs(self):
-    args, kwargs = self.sample_args, self.sample_kwargs
-    if args is not None:
-      if not isinstance(args, tuple):
-        # TODO(b/352584188): Check value types
-        raise ValueError("sample_args must be a tuple of torch tensors.")
-    if kwargs is not None:
-      if not isinstance(kwargs, dict) or not all(
-          isinstance(key, str) for key in kwargs.keys()
-      ):
-        # TODO(b/352584188): Check value types
-        raise ValueError("sample_kwargs must be a dict of string to tensor.")
-
-    args = args if args is not None else tuple()
-    kwargs = kwargs if kwargs is not None else {}
-    return args, kwargs
-
-  @property
-  def flat_arg_names(self) -> list[str]:
-    spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
-    args_spec, kwargs_spec = spec.children_specs
-
-    names = []
-    for i in range(args_spec.num_leaves):
-      names.append(f"args_{i}")
-
-    kwargs_names = self._flat_kwarg_names(
-        kwargs_spec.children_specs, kwargs_spec.context
-    )
-    names.extend(kwargs_names)
-    return names
-
-  def _flat_kwarg_names(self, specs, context) -> List[str]:
-    flat_names = []
-    if context is None:
-      for i, spec in enumerate(specs):
-        if spec.children_specs:
-          flat_names.extend([
-              f"{i}_{name}"
-              for name in self._flat_kwarg_names(
-                  spec.children_specs, spec.context
-              )
-          ])
-        else:
-          flat_names.append(f"{i}")
-    else:
-      flat_ctx = self._flatten_list(context)
-      for prefix, spec in zip(flat_ctx, specs):
-        leaf_flat_names = self._flat_kwarg_names(
-            spec.children_specs, spec.context
-        )
-        if leaf_flat_names:
-          flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
-        else:
-          flat_names.append(prefix)
-
-    return flat_names
-
-  def _flatten_list(self, l: List) -> List:
-    flattened = []
-    for item in l:
-      if isinstance(item, list):
-        flattened.extend(self._flatten_list(item))
-      else:
-        flattened.append(item)
-    return flattened
-
-  @property
-  def flat_args(self) -> tuple[Any]:
-    args, kwargs = self._normalized_sample_args_kwargs
-    return tuple([*args, *kwargs.values()])
+MlirBundle = stablehlo.StableHLOModelBundle
+MergedBundle = stablehlo.StableHLOModelBundle
 
 
-def exported_program_to_stablehlo_bundle(
+def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
     sample_args: tuple[torch.Tensor],
 ) -> stablehlo.StableHLOModelBundle:
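The removed Signature.flat_arg_names above (presumably carried forward in the new ai_edge_torch/_convert/signature.py, which this diff does not expand) derives TFLite input names from the pytree structure of the sample inputs: positional args become args_0, args_1, ... and flat keyword args keep their dict keys. A minimal illustration of that naming scheme for flat (non-nested) inputs:

import torch
import torch.utils._pytree as pytree

sample_args = (torch.randn(2), torch.randn(3))
sample_kwargs = {"mask": torch.ones(2), "scale": torch.tensor(1.0)}

treespec = pytree.tree_flatten((sample_args, sample_kwargs))[1]
args_spec, kwargs_spec = treespec.children_specs

names = [f"args_{i}" for i in range(args_spec.num_leaves)]
names += [str(key) for key in kwargs_spec.context]  # flat kwargs keep their keys
print(names)  # ['args_0', 'args_1', 'mask', 'scale']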
@@ -141,81 +62,12 @@ def exported_program_to_stablehlo_bundle(
   )._bundle
 
 
-def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
-  if not torch_tensor.is_contiguous():
-    torch_tensor = torch_tensor.contiguous()
-
-  try:
-    dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
-    tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
-  except Exception:
-    logging.info(
-        "Can not use dlpack to convert torch tensors. Falling back to numpy."
-    )
-    nparray = torch_tensor.cpu().detach().numpy()
-    tf_tensor = tf.convert_to_tensor(nparray)
-
-  return tf_tensor
-
-
-def _get_states(
-    exported_programs: list[torch.export.ExportedProgram],
-    signatures: list[Signature],
-):
-  for exported_program, signature in zip(exported_programs, signatures):
-    args, _ = exported_program.example_inputs
-    # Calling this to get **all** the state including model buffers.
-    _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
-    for tensor, input_spec in zip(
-        _flat_input_args, exported_program.graph_signature.input_specs
-    ):
-      # Only interested in Tensors that are part of the state (and not user input).
-      if (
-          not isinstance(tensor, torch.Tensor)
-          or input_spec.kind
-          == torch.export.graph_signature.InputKind.USER_INPUT
-      ):
-        continue
-      yield signature, tensor, input_spec
-
-
-def _tensor_unique_id(tensor: torch.Tensor):
-  return (
-      str(tensor.device),
-      tensor.shape,
-      tensor.stride(),
-      tensor.untyped_storage().data_ptr(),
-  )
-
-
-def _gather_state_dict(
-    exported_programs: list[torch.export.ExportedProgram],
-    signatures: list[Signature],
-):
-  deduped_tensor_map = {}
-
-  for _, tensor, _ in _get_states(exported_programs, signatures):
-    unique_id = _tensor_unique_id(tensor)
-    deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
-
-  state_dict = {}
-  for signature, tensor, input_spec in _get_states(
-      exported_programs, signatures
-  ):
-    unique_id = _tensor_unique_id(tensor)
-    state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
-        unique_id
-    ]
-
-  return state_dict
-
-
-def merge_stablehlo_bundles(
+def merge_mlir_bundles(
     bundles: list[stablehlo.StableHLOModelBundle],
-    signatures: list[Signature],
+    signatures: list[signature_module.Signature],
     exported_programs: list[torch.export.ExportedProgram],
 ) -> stablehlo.StableHLOGraphModule:
-  state_dict = _gather_state_dict(exported_programs, signatures)
+  state_dict = common_utils.gather_state_dict(exported_programs, signatures)
 
   new_bundle = stablehlo.StableHLOModelBundle(
       state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
@@ -232,7 +84,7 @@ def merge_stablehlo_bundles(
       loc.name = signature.name + "_" + loc.name
     new_bundle.stablehlo_funcs.append(func)
     new_bundle.additional_constants.extend(bundle.additional_constants)
-  return
+  return new_bundle
 
 
 def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
@@ -264,19 +116,15 @@ def _wrap_as_tf_func(
 
 
 def _make_tf_function(
-    shlo_graph_module: stablehlo.StableHLOGraphModule,
     bundle: stablehlo.StableHLOModelBundle = None,
 ):
-  bundle =
-  return [
-      _wrap_as_tf_func(func, bundle)
-      for func in shlo_graph_module._bundle.stablehlo_funcs
-  ]
+  bundle = bundle if bundle is None else bundle
+  return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]
 
 
 def _make_tf_signature(
     meta: stablehlo.StableHLOFunctionMeta,
-    signature: Signature,
+    signature: signature_module.Signature,
 ) -> list[tf.TensorSpec]:
   input_names = signature.flat_arg_names
   input_pos_to_spec = {
@@ -305,60 +153,33 @@ def _make_tf_signature(
   return ret
 
 
-def
-
-):
-
-
-
-
-    target_obj = converter
-    for idx in range(len(path) - 2):
-      target_obj = getattr(target_obj, path[idx])
-
-    setattr(target_obj, path[-2], path[-1])
-
-  def _iterate_dict_tree(flags_dict: dict, path: list):
-    for key, value in flags_dict.items():
-      path.append(key)
-      if isinstance(value, dict):
-        _iterate_dict_tree(value, path)
-      else:
-        path.append(value)
-        _set_converter_flag(path)
-        path.pop()
-      path.pop()
-
-  _iterate_dict_tree(tfl_converter_flags, [])
-
-
-def _set_tfl_converter_quant_flags(
-    converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
-):
-  if quant_config is not None:
-    quantizer_mode = quant_config._quantizer_mode
-    if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
-      converter._experimental_qdq_conversion_mode = "DYNAMIC"
-    elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
-      converter._experimental_qdq_conversion_mode = "STATIC"
+def exported_program_to_mlir_text(
+    exported_program: torch.export.ExportedProgram,
+) -> str:
+  """Converts a ExportedProgram to a MLIR text."""
+  return stablehlo.exported_program_to_stablehlo(
+      exported_program
+  ).get_stablehlo_text()
 
 
-def
-
-    signatures: list[Signature],
+def merged_bundle_to_tfl_model(
+    bundle: stablehlo.StableHLOModelBundle,
+    signatures: list[signature_module.Signature],
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     _tfl_converter_flags: dict = {},
 ) -> None:
   """Converts a StableHLOGraphModule to a tflite model.
-
-
-
+
+  Args: shlo_bundle - model to export and save
+
+    signatures: List of signatures from which names of the signatures is
+      extracted.
     quant_config: User-defined quantization method and scheme of the model.
-    _tfl_converter_flags: A nested dictionary allowing setting flags for the
+    _tfl_converter_flags: A nested dictionary allowing setting flags for the
+      underlying tflite converter.
   """
 
-  bundle = shlo_graph_module._bundle
   tf_module = tf.Module()
   bundle.state_dict = {
       k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
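The removed helpers above walked the nested _tfl_converter_flags dict and set each leaf value onto the converter via getattr/setattr; that behavior presumably now lives in conversion_utils.apply_tfl_converter_flags (added in this release but not expanded in this diff). A standalone sketch of the same traversal, using a stand-in object rather than a real tf.lite.TFLiteConverter:

from types import SimpleNamespace


def apply_flags_sketch(target, flags: dict):
  """Recursively set leaves of a nested flags dict as attributes on target."""
  for key, value in flags.items():
    if isinstance(value, dict):
      apply_flags_sketch(getattr(target, key), value)
    else:
      setattr(target, key, value)


# Stand-in for a converter object with a nested target_spec attribute.
fake_converter = SimpleNamespace(target_spec=SimpleNamespace())
apply_flags_sketch(
    fake_converter,
    {
        "experimental_enable_resource_variables": True,
        "target_spec": {"supported_ops": ["TFLITE_BUILTINS"]},
    },
)
assert fake_converter.target_spec.supported_ops == ["TFLITE_BUILTINS"]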
@@ -371,7 +192,7 @@ def convert_stablehlo_to_tflite(
       for func, sig in zip(bundle.stablehlo_funcs, signatures)
   )
 
-  tf_functions = _make_tf_function(
+  tf_functions = _make_tf_function(bundle)
 
   tf_module.f = []
   for tf_sig, func in zip(tf_signatures, tf_functions):
@@ -413,7 +234,7 @@ def convert_stablehlo_to_tflite(
   converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
   converter._experimental_enable_composite_direct_lowering = True
 
-  _set_tfl_converter_quant_flags(converter, quant_config)
+  conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
   if (
       quant_config is not None
       and quant_config._quantizer_mode
@@ -423,7 +244,7 @@ def convert_stablehlo_to_tflite(
         quant_config.generative_recipe
     )
 
-
+  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
 
ai_edge_torch/model.py CHANGED
@@ -15,17 +15,18 @@
 
 """Represents an ai_edge_torch model.
 
-PyTorch models can be converted to this representation through
+PyTorch models can be converted to this representation through
+`ai_edge_torch.convert`.
 """
 from __future__ import annotations
 
 import abc
 
-from ai_edge_torch.convert import conversion_utils as cutils
-import numpy as np
 import numpy.typing as npt
 import tensorflow as tf
 
+DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
 
 class Model(abc.ABC):
   """Represents and edge model."""
@@ -34,7 +35,7 @@ class Model(abc.ABC):
   def __call__(
       self,
       *args: npt.ArrayLike,
-      signature_name: str =
+      signature_name: str = DEFAULT_SIGNATURE_NAME,
       **kwargs,
   ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
     raise NotImplementedError()
@@ -66,18 +67,22 @@ class TfLiteModel(Model):
   def __call__(
       self,
       *args: npt.ArrayLike,
-      signature_name: str =
+      signature_name: str = DEFAULT_SIGNATURE_NAME,
      **kwargs,
   ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
     """Runs inference on the edge model using the provided arguments.
 
     Args:
       *args: The arguments to be passed to the model for inference.
-      **kwargs: The arguments with specific names to be passed to the model for
-
-
+      **kwargs: The arguments with specific names to be passed to the model for
+        inference.
+      signature_name: The name of the signature to be used for inference. The
+        default signature is used if not provided.
     """
-    interpreter = tf.lite.Interpreter(
+    interpreter = tf.lite.Interpreter(
+        model_content=self._tflite_model,
+        experimental_default_delegate_latest_features=True,
+    )
     interpreter.allocate_tensors()
 
     signature_list = interpreter.get_signature_list()