ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
from .convert.converter import convert
|
|
17
|
+
from .convert.converter import signature
|
|
18
|
+
from .convert.to_channel_last_io import to_channel_last_io
|
|
19
|
+
from .model import Model
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load(path: str) -> Model:
    """Load a serialized ai_edge_torch model from disk.

    Convenience wrapper that simply delegates to ``Model.load``.

    Args:
      path: Filesystem path of the serialized ai_edge_torch model.

    Returns:
      The ``ai_edge_torch.model.Model`` produced by ``Model.load(path)``.
    """
    loaded_model = Model.load(path)
    return loaded_model
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import gc
|
|
17
|
+
import logging
|
|
18
|
+
import os
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
from torch.export import ExportedProgram
|
|
23
|
+
from torch_xla import stablehlo
|
|
24
|
+
|
|
25
|
+
from ai_edge_torch import model
|
|
26
|
+
from ai_edge_torch.convert import conversion_utils as cutils
|
|
27
|
+
from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
|
|
28
|
+
from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass # NOQA
|
|
29
|
+
from ai_edge_torch.convert.fx_passes import CanonicalizePass
|
|
30
|
+
from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
|
|
31
|
+
from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
|
|
32
|
+
from ai_edge_torch.convert.fx_passes import run_passes
|
|
33
|
+
from ai_edge_torch.generative.fx_passes import run_generative_passes
|
|
34
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
35
|
+
|
|
36
|
+
# Set at module import time, i.e. before any conversion in this module runs.
# NOTE(review): presumably opts torch_xla into unbounded dynamic-shape lowering
# (relevant for `Signature.dynamic_shapes`); torch_xla.stablehlo is already
# imported above, so confirm the variable is read lazily at export time.
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _run_convert_passes(
    exported_program: ExportedProgram,
) -> ExportedProgram:
    """Apply the generative passes, then the default conversion FX pass pipeline.

    Args:
      exported_program: Program produced by ``torch.export.export``.

    Returns:
      The ExportedProgram returned by ``run_passes`` after the full pipeline.
    """
    # Generative-model rewrites run first, ahead of the default pipeline.
    program = run_generative_passes(exported_program)

    # A CanonicalizePass follows every other pass; keep this order unchanged.
    default_pipeline = [
        BuildInterpolateCompositePass(),
        CanonicalizePass(),
        OptimizeLayoutTransposesPass(),
        CanonicalizePass(),
        BuildAtenCompositePass(),
        CanonicalizePass(),
        InjectMlirDebuginfoPass(),
        CanonicalizePass(),
    ]
    return run_passes(program, default_pipeline)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _warn_training_modules(signatures: list[cutils.Signature]):
    """Log a warning for every signature whose module is still in training mode.

    When there is exactly one signature and it carries the default signature
    name (the caller did not name it), the warning omits the name; otherwise the
    quoted signature name is included in the message.

    Args:
      signatures: Signatures about to be converted.
    """
    # `{sig_name}` is filled with either "" or '"<name>" ' (note trailing space).
    message_template = (
        "Your model {sig_name}is converted in training mode. "
        "Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
    )
    for sig in signatures:
        if not sig.module.training:
            continue

        if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
            # User does not specify any signature names explicitly.
            message = message_template.format(sig_name="")
        else:
            message = message_template.format(sig_name=f'"{sig.name}" ')

        # `logging.warn` is a deprecated alias; `logging.warning` is the supported API.
        logging.warning(message)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def convert_signatures(
    signatures: list[cutils.Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: Optional[dict] = None,
) -> model.TfLiteModel:
    """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.

    Args:
      signatures: The list of `Signature` objects containing PyTorch modules to
        be converted.
      quant_config: User-defined quantization method and scheme of the model.
      _tfl_converter_flags: A nested dictionary allowing setting flags for the
        underlying tflite converter. ``None`` (the default) means no extra flags.

    Returns:
      A `model.TfLiteModel` wrapping the converted model.
    """
    # Avoid a shared mutable default argument.
    if _tfl_converter_flags is None:
        _tfl_converter_flags = {}

    _warn_training_modules(signatures)

    exported_programs: list[torch.export.ExportedProgram] = [
        torch.export.export(
            sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes
        )
        for sig in signatures
    ]

    # Apply default fx passes
    exported_programs = list(map(_run_convert_passes, exported_programs))
    shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
        cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
        for exported, sig in zip(exported_programs, signatures)
    ]

    merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
        cutils.merge_stablehlo_bundles(shlo_bundles, signatures, exported_programs)
    )
    # Drop the intermediate programs/bundles and collect them before the
    # TFLite conversion step.
    del exported_programs
    del shlo_bundles

    gc.collect()

    tflite_model = cutils.convert_stablehlo_to_tflite(
        merged_shlo_graph_module,
        signatures,
        quant_config=quant_config,
        _tfl_converter_flags=_tfl_converter_flags,
    )

    return model.TfLiteModel(tflite_model)
|
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
import collections
|
|
17
|
+
import copy
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
import gc
|
|
20
|
+
import itertools
|
|
21
|
+
import logging
|
|
22
|
+
import tempfile
|
|
23
|
+
from typing import Any, Dict, Optional, Tuple, Union
|
|
24
|
+
|
|
25
|
+
import torch
|
|
26
|
+
import torch.utils._pytree as pytree
|
|
27
|
+
from torch_xla import stablehlo
|
|
28
|
+
|
|
29
|
+
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
|
30
|
+
from ai_edge_torch.quantize import quant_config as qcfg
|
|
31
|
+
|
|
32
|
+
try:
|
|
33
|
+
import tensorflow as tf
|
|
34
|
+
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
|
35
|
+
|
|
36
|
+
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
|
|
37
|
+
except ImportError:
|
|
38
|
+
logging.error(
|
|
39
|
+
"This module needs tensorflow with xla support.\n"
|
|
40
|
+
"Please install tensorflow with `pip install tf-nightly`.\n"
|
|
41
|
+
)
|
|
42
|
+
raise
|
|
43
|
+
|
|
44
|
+
# Default signature name, reusing TensorFlow's default serving signature key.
# A lone signature with this name is treated as "not named by the caller".
DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class Signature:
    """A named conversion entry point: a module plus the sample inputs to trace it.

    Each instance becomes one signature of the converted model. The signature
    name is also used as a prefix for the state/parameter names of that
    signature when several signatures are merged.
    """

    name: str  # Signature name (and prefix for this signature's state names).
    module: torch.nn.Module  # The PyTorch module to convert.
    sample_args: tuple[torch.Tensor, ...]  # Positional sample inputs (must be a tuple).
    sample_kwargs: dict[str, torch.Tensor]  # Keyword sample inputs (str keys).
    # Optional dynamic-shape spec forwarded to `torch.export.export`.
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None

    @property
    def _normalized_sample_args_kwargs(self) -> tuple[tuple, dict]:
        """Validate the sample inputs and return them as an ``(args, kwargs)`` pair.

        ``None`` args / kwargs are replaced by an empty tuple / empty dict.

        Raises:
          ValueError: if ``sample_args`` is not a tuple, or ``sample_kwargs`` is
            not a dict whose keys are all strings.
        """
        args, kwargs = self.sample_args, self.sample_kwargs
        if args is not None:
            if not isinstance(args, tuple):
                # TODO(b/352584188): Check value types
                raise ValueError("sample_args must be a tuple of torch tensors.")
        if kwargs is not None:
            if not isinstance(kwargs, dict) or not all(
                isinstance(key, str) for key in kwargs.keys()
            ):
                # TODO(b/352584188): Check value types
                raise ValueError("sample_kwargs must be a dict of string to tensor.")

        args = args if args is not None else tuple()
        kwargs = kwargs if kwargs is not None else {}
        return args, kwargs

    @property
    def flat_arg_names(self) -> list[str]:
        """Return one input name per flattened leaf of ``(args, kwargs)``.

        Built from the same ``pytree.tree_flatten`` call as ``flat_args``, so the
        names are meant to line up one-to-one with ``flat_args``. Positional
        leaves are named ``args_<i>`` (``i`` counts leaves, not arguments). A
        keyword input that flattens to one leaf keeps its keyword name; one that
        flattens to several leaves yields ``<name>_<i>`` for each leaf.
        """
        spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
        # The flattened object is an (args, kwargs) tuple -> two child specs.
        args_spec, kwargs_spec = spec.children_specs

        names = []
        for i in range(args_spec.num_leaves):
            names.append(f"args_{i}")

        # Keys of the kwargs dict spec. For a defaultdict the spec context also
        # carries the default_factory, so the keys are taken from index 1.
        dict_context = (
            kwargs_spec.context
            if kwargs_spec.type is not collections.defaultdict
            # ignore mismatch of `default_factory` for defaultdict
            else kwargs_spec.context[1]
        )

        for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
            if value_spec.num_leaves == 1:
                names.append(name)
            else:
                # value_spec.num_leaves may be greater than 1 when the value is a (nested)
                # tuple of tensors. We haven't decided how we should support flattenable
                # tensor containers as inputs.
                # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
                for i in range(value_spec.num_leaves):
                    names.append(f"{name}_{i}")
        return names

    @property
    def flat_args(self) -> tuple[torch.Tensor, ...]:
        """Return the pytree-flattened leaves of ``(sample_args, sample_kwargs)``."""
        return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def exported_program_to_stablehlo_bundle(
    exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
) -> stablehlo.StableHLOModelBundle:
    """Lower an ExportedProgram to a StableHLO model bundle through torch_xla.

    Args:
      exported_program: The program to lower.
      sample_args: Flat sample inputs, used as ``override_tracing_arguments``.

    Returns:
      The ``_bundle`` of the StableHLO graph module built by torch_xla. Its
      ``state_dict`` is empty because weights are not exported.
    """
    # export_weights=False keeps pytorch/xla from copying the weights into numpy
    # arrays (memory bloat); consequently the returned bundle's state_dict is empty.
    export_options = stablehlo.StableHLOExportOptions(
        override_tracing_arguments=sample_args,
        export_weights=False,
    )
    graph_module = stablehlo.exported_program_to_stablehlo(
        exported_program, export_options
    )
    return graph_module._bundle
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
    """Convert a torch tensor into a TensorFlow tensor.

    A non-contiguous input is made contiguous first. Conversion is attempted via
    DLPack; if that raises for any reason, the tensor is moved to CPU, detached,
    turned into a NumPy array and handed to ``tf.convert_to_tensor``.
    """
    source = torch_tensor if torch_tensor.is_contiguous() else torch_tensor.contiguous()

    try:
        capsule = torch.utils.dlpack.to_dlpack(source)
        return tf.experimental.dlpack.from_dlpack(capsule)
    except Exception:
        logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")

    return tf.convert_to_tensor(source.cpu().detach().numpy())
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _get_states(
    exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
):
    """Yield ``(signature, tensor, input_spec)`` for each state tensor of each program.

    Walks every exported program's flat graph inputs alongside its input specs.
    Non-tensor inputs and USER_INPUT inputs are skipped, so only model state
    (e.g. parameters and buffers) is yielded.
    """
    user_input_kind = torch.export.graph_signature.InputKind.USER_INPUT
    for program, sig in zip(exported_programs, signatures):
        example_args, _ = program.example_inputs
        # Calling this to get **all** the state including model buffers.
        flat_inputs = program._graph_module_flat_inputs(example_args, {})
        input_specs = program.graph_signature.input_specs
        for tensor, input_spec in zip(flat_inputs, input_specs):
            is_state_tensor = (
                isinstance(tensor, torch.Tensor) and input_spec.kind != user_input_kind
            )
            if is_state_tensor:
                yield sig, tensor, input_spec
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _tensor_unique_id(tensor: torch.Tensor):
|
|
155
|
+
return (
|
|
156
|
+
str(tensor.device),
|
|
157
|
+
tensor.shape,
|
|
158
|
+
tensor.stride(),
|
|
159
|
+
tensor.untyped_storage().data_ptr(),
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _gather_state_dict(
    exported_programs: list[torch.export.ExportedProgram],
    signatures: list[Signature],
):
    """Collect the state tensors of all signatures into a single TF state dict.

    Keys are ``"<signature name>_<input spec target>"``. State tensors that share
    the same ``_tensor_unique_id`` are converted to a TF tensor once, and that
    single TF tensor is stored under every key that refers to it.

    Args:
      exported_programs: Exported programs, one per signature.
      signatures: Signatures matching ``exported_programs`` position by position.

    Returns:
      A dict mapping the prefixed state names to TF tensors.
    """
    converted_by_id = {}
    state_dict = {}
    # Single walk over the states: the flat inputs are computed once per program,
    # and each distinct tensor is converted only on first sight.
    for signature, tensor, input_spec in _get_states(exported_programs, signatures):
        unique_id = _tensor_unique_id(tensor)
        if unique_id not in converted_by_id:
            converted_by_id[unique_id] = _torch_to_tf_tensor(tensor)
        state_dict[signature.name + "_" + input_spec.target] = converted_by_id[unique_id]

    return state_dict
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def merge_stablehlo_bundles(
    bundles: list[stablehlo.StableHLOModelBundle],
    signatures: list[Signature],
    exported_programs: list[torch.export.ExportedProgram],
) -> stablehlo.StableHLOGraphModule:
    """Merge per-signature StableHLO bundles into one graph module.

    The merged state dict is gathered from ``exported_programs`` (bundles built by
    ``exported_program_to_stablehlo_bundle`` carry an empty state dict). Every
    function and every PARAMETER input location gets a ``"<signature name>_"``
    prefix. CONSTANT input positions are shifted so they index into the
    concatenated ``additional_constants`` list.

    Note: functions and input locations from ``bundles`` are mutated in place and
    then appended to the merged bundle.
    """
    merged = stablehlo.StableHLOModelBundle(
        state_dict=_gather_state_dict(exported_programs, signatures),
        additional_constants=[],
        stablehlo_funcs=[],
    )

    for bundle, signature in zip(bundles, signatures):
        prefix = signature.name + "_"
        # This bundle's constants are appended after everything merged so far.
        offset = len(merged.additional_constants)
        for func in bundle.stablehlo_funcs:
            func.meta.name = prefix + func.meta.name
            for loc in func.meta.input_locations:
                if loc.type_ == stablehlo.VariableType.CONSTANT:
                    loc.position += offset
                elif loc.type_ == stablehlo.VariableType.PARAMETER:
                    loc.name = prefix + loc.name
            merged.stablehlo_funcs.append(func)
        merged.additional_constants.extend(bundle.additional_constants)

    return stablehlo.StableHLOGraphModule(merged)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
    """Return a shallow copy of ``signature.shape`` with every dynamic dim set to None.

    ``None`` is TensorFlow's marker for an unknown dimension. The signature's own
    shape object is not modified.
    """
    shape_with_none = copy.copy(signature.shape)
    for dynamic_axis in signature.dynamic_dims:
        shape_with_none[dynamic_axis] = None
    return shape_with_none
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _wrap_as_tf_func(
    func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
):
    """Build a callable that executes ``func``'s StableHLO bytecode through TF XLA.

    The callable takes the signature inputs positionally and builds the full call
    parameter list with ``stablehlo._extract_call_parameters`` against
    ``bundle``. It then calls ``tfxla.call_module``, with output dtypes and
    shapes (dynamic dims as None) taken from ``func.meta.output_signature``.
    """

    def inner(*args):
        outputs = func.meta.output_signature
        out_dtypes = [out.dtype for out in outputs]
        out_shapes = [_get_shape_with_dynamic(out) for out in outputs]
        params = stablehlo._extract_call_parameters(args, func.meta, bundle)
        return tfxla.call_module(
            tuple(params),
            version=5,
            Tout=out_dtypes,
            Sout=out_shapes,
            function_list=[],
            module=func.bytecode,
        )

    return inner
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _make_tf_function(
    shlo_graph_module: stablehlo.StableHLOGraphModule,
    bundle: Optional[stablehlo.StableHLOModelBundle] = None,
):
    """Wrap every StableHLO function of a bundle as a TF-callable.

    Args:
      shlo_graph_module: Graph module whose ``_bundle`` is used when ``bundle``
        is not provided.
      bundle: Bundle supplying both the functions and their call parameters
        (state dict / constants). Defaults to ``shlo_graph_module._bundle``.

    Returns:
      One callable (from ``_wrap_as_tf_func``) per function in
      ``bundle.stablehlo_funcs``, in the same order.
    """
    if bundle is None:
        bundle = shlo_graph_module._bundle
    # Iterate the resolved bundle (not shlo_graph_module._bundle) so each function
    # is wrapped together with the bundle its own parameters come from.
    return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _make_tf_signature(
    meta: stablehlo.StableHLOFunctionMeta,
    signature: Signature,
) -> list[tf.TensorSpec]:
    """Build the TF input signature (one TensorSpec per flat input) of a function.

    Specs come from the INPUT_ARG locations of ``meta``, covering both used inputs
    and ``meta.unused_inputs``. They are ordered by input position and named
    after ``signature.flat_arg_names``. Dynamic dims become ``None``. The
    primitive dtype names "int" / "float" map to "int32" / "float32"; any other
    dtype is passed through unchanged.

    Raises:
      ValueError: if the number of INPUT_ARG specs differs from the number of
        flat input names.
    """
    input_names = signature.flat_arg_names
    input_pos_to_spec = {
        loc.position: spec
        for loc, spec in itertools.chain(
            zip(meta.input_locations, meta.input_signature), meta.unused_inputs
        )
        if loc.type_ == stablehlo.VariableType.INPUT_ARG
    }
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(input_pos_to_spec) != len(input_names):
        raise ValueError(
            f"Signature {signature.name!r} has {len(input_names)} flat inputs "
            f"but the StableHLO function has {len(input_pos_to_spec)} input args."
        )

    primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
    ret: list[tf.TensorSpec] = []
    for i, name in enumerate(input_names):
        spec = input_pos_to_spec[i]
        ret.append(
            tf.TensorSpec(
                shape=_get_shape_with_dynamic(spec),
                dtype=primitive_type_to_tf_type.get(spec.dtype, spec.dtype),
                name=name,
            )
        )
    return ret
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _apply_tfl_backdoor_flags(
    converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict
):
    """Set (possibly nested) attributes on ``converter`` from a nested dict.

    A leaf reached through keys ``k1, ..., kn`` with value ``v`` results in
    ``setattr(getattr(...getattr(converter, k1)..., k_{n-1}), kn, v)``. Leaves
    are applied in dict iteration order.
    """

    def _set_converter_flag(path: list):
        if len(path) < 2:
            raise ValueError("Expecting at least two values in the path.")
        *attr_chain, attr_name, attr_value = path
        owner = converter
        for attr in attr_chain:
            owner = getattr(owner, attr)
        setattr(owner, attr_name, attr_value)

    def _visit(flags_node: dict, prefix: tuple):
        for key, value in flags_node.items():
            if isinstance(value, dict):
                _visit(value, prefix + (key,))
            else:
                _set_converter_flag([*prefix, key, value])

    _visit(tfl_converter_flags, ())
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def _set_tfl_converter_quant_flags(
    converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
):
    """Pick the converter's QDQ conversion mode from a PT2E quantization config.

    PT2E dynamic maps to "DYNAMIC" and PT2E static to "STATIC". With no config,
    or with any other quantizer mode, the converter is left untouched.
    """
    if quant_config is None:
        return

    modes = qcfg.QuantConfig._QuantizerMode
    qdq_mode_by_quantizer = {
        modes.PT2E_DYNAMIC: "DYNAMIC",
        modes.PT2E_STATIC: "STATIC",
    }
    qdq_mode = qdq_mode_by_quantizer.get(quant_config._quantizer_mode)
    if qdq_mode is not None:
        converter._experimental_qdq_conversion_mode = qdq_mode
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def convert_stablehlo_to_tflite(
    shlo_graph_module: stablehlo.StableHLOGraphModule,
    signatures: list[Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: Optional[dict] = None,
) -> bytes:
    """Converts a StableHLOGraphModule to a tflite model.

    Args:
      shlo_graph_module: Model to export and convert. Its bundle's
        ``state_dict`` and ``additional_constants`` are replaced in place by
        non-trainable ``tf.Variable`` wrappers.
      signatures: Signatures, in the same order as the bundle's functions, from
        which the signature names and input names are taken.
      quant_config: User-defined quantization method and scheme of the model.
      _tfl_converter_flags: A nested dictionary allowing setting flags for the
        underlying tflite converter. ``None`` (the default) means no extra flags.

    Returns:
      The model returned by ``TFLiteConverter.convert()``; in AI Edge Quantizer
      mode, the result of ``translate_recipe.quantize_model`` applied to it.
    """
    # Avoid a shared mutable default argument.
    if _tfl_converter_flags is None:
        _tfl_converter_flags = {}
    use_ai_edge_quantizer = (
        quant_config is not None
        and quant_config._quantizer_mode
        == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
    )

    bundle = shlo_graph_module._bundle
    tf_module = tf.Module()
    bundle.state_dict = {
        k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
    }
    bundle.additional_constants = [
        tf.Variable(v, trainable=False) for v in bundle.additional_constants
    ]
    tf_signatures: list[list[tf.TensorSpec]] = [
        _make_tf_signature(func.meta, sig)
        for func, sig in zip(bundle.stablehlo_funcs, signatures)
    ]

    tf_functions = _make_tf_function(shlo_graph_module, bundle)

    tf_module.f = []
    for tf_sig, func in zip(tf_signatures, tf_functions):
        tf_module.f.append(
            tf.function(
                func,
                input_signature=tf_sig,
            )
        )

    tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
    del bundle
    gc.collect()

    tf_concrete_funcs = [
        func.get_concrete_function(*tf_sig)
        for func, tf_sig in zip(tf_module.f, tf_signatures)
    ]

    translated_recipe = None
    # We need to temporarily save since TFLite's from_concrete_functions does not
    # allow providing names for each of the concrete functions.
    with tempfile.TemporaryDirectory() as temp_dir_path:
        tf.saved_model.save(
            tf_module,
            temp_dir_path,
            signatures={
                sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
            },
        )
        # Clean up intermediate memory early.
        del tf_module
        del tf_concrete_funcs
        gc.collect()

        # The converter reads the SavedModel, so conversion stays inside the
        # temporary directory's lifetime.
        converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
        converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
        converter._experimental_enable_composite_direct_lowering = True

        _set_tfl_converter_quant_flags(converter, quant_config)
        if use_ai_edge_quantizer:
            translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
                quant_config.generative_recipe
            )

        _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)

        tflite_model = converter.convert()

    if use_ai_edge_quantizer:
        tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)

    return tflite_model
|