ai-edge-torch-nightly 0.1.dev202405131930__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (91)
  1. ai_edge_torch/__init__.py +30 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +330 -0
  5. ai_edge_torch/convert/converter.py +171 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +192 -0
  9. ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py +84 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +196 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +286 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +273 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +171 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/debug/__init__.py +16 -0
  27. ai_edge_torch/debug/culprit.py +423 -0
  28. ai_edge_torch/debug/test/__init__.py +14 -0
  29. ai_edge_torch/debug/test/test_culprit.py +133 -0
  30. ai_edge_torch/debug/utils.py +48 -0
  31. ai_edge_torch/experimental/__init__.py +14 -0
  32. ai_edge_torch/generative/__init__.py +14 -0
  33. ai_edge_torch/generative/examples/__init__.py +14 -0
  34. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  35. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  36. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  37. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  39. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  40. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  42. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  43. ai_edge_torch/generative/examples/t5/t5_attention.py +255 -0
  44. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  45. ai_edge_torch/generative/examples/test_models/toy_model.py +119 -0
  46. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  47. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  48. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  49. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  50. ai_edge_torch/generative/layers/__init__.py +14 -0
  51. ai_edge_torch/generative/layers/attention.py +288 -0
  52. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  53. ai_edge_torch/generative/layers/builder.py +103 -0
  54. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  55. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  56. ai_edge_torch/generative/layers/model_config.py +135 -0
  57. ai_edge_torch/generative/layers/normalization.py +62 -0
  58. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  59. ai_edge_torch/generative/quantize/__init__.py +14 -0
  60. ai_edge_torch/generative/quantize/example.py +45 -0
  61. ai_edge_torch/generative/quantize/quant_attrs.py +66 -0
  62. ai_edge_torch/generative/quantize/quant_recipe.py +106 -0
  63. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  64. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  65. ai_edge_torch/generative/quantize/supported_schemes.py +31 -0
  66. ai_edge_torch/generative/test/__init__.py +14 -0
  67. ai_edge_torch/generative/test/test_model_conversion.py +201 -0
  68. ai_edge_torch/generative/test/test_quantize.py +109 -0
  69. ai_edge_torch/generative/utilities/__init__.py +15 -0
  70. ai_edge_torch/generative/utilities/loader.py +290 -0
  71. ai_edge_torch/generative/utilities/t5_loader.py +467 -0
  72. ai_edge_torch/hlfb/__init__.py +16 -0
  73. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  74. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  75. ai_edge_torch/hlfb/mark_pattern/pattern.py +260 -0
  76. ai_edge_torch/hlfb/test/__init__.py +14 -0
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  79. ai_edge_torch/model.py +134 -0
  80. ai_edge_torch/quantize/__init__.py +16 -0
  81. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  83. ai_edge_torch/quantize/quant_config.py +85 -0
  84. ai_edge_torch/testing/__init__.py +14 -0
  85. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  86. ai_edge_torch/testing/model_coverage/model_coverage.py +126 -0
  87. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/LICENSE +202 -0
  88. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/METADATA +38 -0
  89. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/RECORD +91 -0
  90. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/WHEEL +5 -0
  91. ai_edge_torch_nightly-0.1.dev202405131930.dist-info/top_level.txt +1 -0
@@ -0,0 +1,30 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from .convert.converter import convert
+ from .convert.converter import signature
+ from .model import Model
+
+
+ def load(path: str) -> Model:
+   """Imports an ai_edge_torch model from disk.
+
+   Args:
+     path: The path to the serialized ai_edge_torch model.
+
+   Returns:
+     An ai_edge_torch.model.Model object.
+   """
+   return Model.load(path)
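This __init__.py is the package's public surface: convert and signature re-exported from convert.converter, plus Model and load for reading a serialized model back from disk. A minimal sketch of how these entry points compose; the module and tensor shapes are hypothetical, and the edge model's export() method is assumed from ai_edge_torch/model.py, which is not shown in this hunk:

  import torch
  import ai_edge_torch

  # Hypothetical model and sample inputs; any eval-mode torch.nn.Module works.
  torch_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU()).eval()
  sample_args = (torch.randn(1, 4),)

  edge_model = ai_edge_torch.convert(torch_model, sample_args)  # -> model.TfLiteModel
  edge_model.export("/tmp/linear_relu.tflite")  # export() assumed from model.py (not in this hunk)
  reloaded = ai_edge_torch.load("/tmp/linear_relu.tflite")  # round-trips via Model.load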
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
@@ -0,0 +1,117 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import gc
+ import logging
+ import os
+ from typing import Optional
+
+ import torch
+ from torch.export import ExportedProgram
+ from torch_xla import stablehlo
+
+ from ai_edge_torch import model
+ from ai_edge_torch.convert import conversion_utils as cutils
+ from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
+ from ai_edge_torch.convert.fx_passes import BuildUpsampleBilinear2DCompositePass  # NOQA
+ from ai_edge_torch.convert.fx_passes import CanonicalizePass
+ from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
+ from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
+ from ai_edge_torch.convert.fx_passes import run_passes
+ from ai_edge_torch.quantize import quant_config as qcfg
+
+ os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
+ def _run_convert_passes(
+     exported_program: ExportedProgram,
+ ) -> ExportedProgram:
+   return run_passes(
+       exported_program,
+       [
+           BuildUpsampleBilinear2DCompositePass(),
+           CanonicalizePass(),
+           OptimizeLayoutTransposesPass(),
+           CanonicalizePass(),
+           BuildAtenCompositePass(),
+           CanonicalizePass(),
+           InjectMlirDebuginfoPass(),
+           CanonicalizePass(),
+       ],
+   )
+
+
+ def _warn_training_modules(signatures: list[cutils.Signature]):
+   for sig in signatures:
+     if not sig.module.training:
+       continue
+
+     message = (
+         "Your model {sig_name}is converted in training mode. "
+         "Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
+     )
+     if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
+       # User does not specify any signature names explicitly.
+       message = message.format(sig_name="")
+     else:
+       message = message.format(sig_name=f'"{sig.name}" ')
+
+     logging.warn(message)
+
+
+ def convert_signatures(
+     signatures: list[cutils.Signature],
+     *,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     _tfl_converter_flags: dict = {},
+ ) -> model.TfLiteModel:
+   """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.
+   Args:
+     signatures: The list of `Signature` objects containing PyTorch modules to be converted.
+     quant_config: User-defined quantization method and scheme of the model.
+     _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
+   """
+   _warn_training_modules(signatures)
+
+   exported_programs: torch.export.ExportedProgram = [
+       torch.export.export(
+           sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes
+       )
+       for sig in signatures
+   ]
+
+   # Apply default fx passes
+   exported_programs = list(map(_run_convert_passes, exported_programs))
+   shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
+       cutils.exported_program_to_stablehlo_bundle(exported, sig.sample_args)
+       for exported, sig in zip(exported_programs, signatures)
+   ]
+
+   merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
+       cutils.merge_stablehlo_bundles(shlo_bundles, signatures, exported_programs)
+   )
+   del exported_programs
+   del shlo_bundles
+
+   gc.collect()
+
+   tflite_model = cutils.convert_stablehlo_to_tflite(
+       merged_shlo_graph_module,
+       signatures,
+       quant_config=quant_config,
+       _tfl_converter_flags=_tfl_converter_flags,
+   )
+
+   return model.TfLiteModel(tflite_model)
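conversion.py is the internal pipeline: each Signature's module goes through torch.export, the default FX passes, a per-signature StableHLO bundle, a merged graph module, and finally the TFLite converter. A hedged sketch of exercising this internal entry point directly (normally reached via converter.Converter; the example module below is made up):

  import torch
  from ai_edge_torch.convert import conversion
  from ai_edge_torch.convert import conversion_utils as cutils

  # Hypothetical toy module; any eval-mode torch.nn.Module that torch.export can trace works.
  class TinyModel(torch.nn.Module):
    def __init__(self):
      super().__init__()
      self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
      return self.linear(x)

  sig = cutils.Signature(
      name=cutils.DEFAULT_SIGNATURE_NAME,
      module=TinyModel().eval(),
      sample_args=(torch.randn(1, 8),),
  )
  tfl_model = conversion.convert_signatures([sig])  # returns model.TfLiteModel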
@@ -0,0 +1,330 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import copy
+ from dataclasses import dataclass
+ import gc
+ import itertools
+ import logging
+ import tempfile
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ from torch_xla import stablehlo
+
+ from ai_edge_torch.quantize import quant_config as qcfg
+
+ try:
+   import tensorflow as tf
+   from tensorflow.compiler.tf2xla.python import xla as tfxla
+
+   from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb  # isort:skip
+ except ImportError:
+   logging.error(
+       "This module needs tensorflow with xla support.\n"
+       "Please install tensorflow with `pip install tf-nightly`.\n"
+   )
+   raise
+
+ DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+
+ @dataclass
+ class Signature:
+   name: str
+   module: torch.nn.Module
+   sample_args: tuple[torch.Tensor]
+   dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+
+
+ def exported_program_to_stablehlo_bundle(
+     exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
+ ) -> stablehlo.StableHLOModelBundle:
+   # Setting export_weights to False here so that pytorch/xla avoids copying the weights
+   # to a numpy array which would lead to memory bloat. This means that the state_dict
+   # in the returned bundle is going to be empty.
+   return stablehlo.exported_program_to_stablehlo(
+       exported_program,
+       stablehlo.StableHLOExportOptions(
+           override_tracing_arguments=sample_args, export_weights=False
+       ),
+   )._bundle
+
+
+ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
+   if not torch_tensor.is_contiguous():
+     torch_tensor = torch_tensor.contiguous()
+
+   try:
+     dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
+     tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
+   except Exception:
+     logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")
+     nparray = torch_tensor.cpu().detach().numpy()
+     tf_tensor = tf.convert_to_tensor(nparray)
+
+   return tf_tensor
+
+
+ def _get_states(
+     exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
+ ):
+   for exported_program, signature in zip(exported_programs, signatures):
+     args, _ = exported_program.example_inputs
+     # Calling this to get **all** the state including model buffers.
+     _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
+     for tensor, input_spec in zip(
+         _flat_input_args, exported_program.graph_signature.input_specs
+     ):
+       # Only interested in Tensors that are part of the state (and not user input).
+       if (
+           not isinstance(tensor, torch.Tensor)
+           or input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
+       ):
+         continue
+       yield signature, tensor, input_spec
+
+
+ def _tensor_unique_id(tensor: torch.Tensor):
+   return (
+       str(tensor.device),
+       tensor.shape,
+       tensor.stride(),
+       tensor.untyped_storage().data_ptr(),
+   )
+
+
+ def _gather_state_dict(
+     exported_programs: list[torch.export.ExportedProgram],
+     signatures: list[Signature],
+ ):
+   deduped_tensor_map = {}
+
+   for _, tensor, _ in _get_states(exported_programs, signatures):
+     unique_id = _tensor_unique_id(tensor)
+     deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
+
+   state_dict = {}
+   for signature, tensor, input_spec in _get_states(exported_programs, signatures):
+     unique_id = _tensor_unique_id(tensor)
+     state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[unique_id]
+
+   return state_dict
+
+
+ def merge_stablehlo_bundles(
+     bundles: list[stablehlo.StableHLOModelBundle],
+     signatures: list[Signature],
+     exported_programs: list[torch.export.ExportedProgram],
+ ) -> stablehlo.StableHLOGraphModule:
+   state_dict = _gather_state_dict(exported_programs, signatures)
+
+   new_bundle = stablehlo.StableHLOModelBundle(
+       state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
+   )
+
+   for bundle, signature in zip(bundles, signatures):
+     const_offset = len(new_bundle.additional_constants)
+     for func in bundle.stablehlo_funcs:
+       func.meta.name = signature.name + "_" + func.meta.name
+       for loc in func.meta.input_locations:
+         if loc.type_ == stablehlo.VariableType.CONSTANT:
+           loc.position += const_offset
+         elif loc.type_ == stablehlo.VariableType.PARAMETER:
+           loc.name = signature.name + "_" + loc.name
+       new_bundle.stablehlo_funcs.append(func)
+     new_bundle.additional_constants.extend(bundle.additional_constants)
+   return stablehlo.StableHLOGraphModule(new_bundle)
+
+
+ def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
+   shape = copy.copy(signature.shape)
+   for i in signature.dynamic_dims:
+     shape[i] = None
+   return shape
+
+
+ def _wrap_as_tf_func(
+     func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
+ ):
+   def inner(*args):
+     type_info = [sig.dtype for sig in func.meta.output_signature]
+     shape_info = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
+     call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
+     return tfxla.call_module(
+         tuple(call_args),
+         version=5,
+         Tout=type_info,
+         Sout=shape_info,
+         function_list=[],
+         module=func.bytecode,
+     )
+
+   return inner
+
+
+ def _make_tf_function(
+     shlo_graph_module: stablehlo.StableHLOGraphModule,
+     bundle: stablehlo.StableHLOModelBundle = None,
+ ):
+   bundle = shlo_graph_module._bundle if bundle is None else bundle
+   return [
+       _wrap_as_tf_func(func, bundle)
+       for func in shlo_graph_module._bundle.stablehlo_funcs
+   ]
+
+
+ def _make_tf_signature(
+     meta: stablehlo.StableHLOFunctionMeta,
+ ) -> list[tf.TensorSpec]:
+   input_pos_to_spec = {
+       loc.position: spec
+       for loc, spec in itertools.chain(
+           zip(meta.input_locations, meta.input_signature), meta.unused_inputs
+       )
+       if loc.type_ == stablehlo.VariableType.INPUT_ARG
+   }
+   primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
+   ret: list[tf.TensorSpec] = []
+   for i in range(len(input_pos_to_spec)):
+     spec = input_pos_to_spec[i]
+     shape = _get_shape_with_dynamic(spec)
+     ret.append(
+         tf.TensorSpec(
+             shape=shape,
+             dtype=primitive_type_to_tf_type[spec.dtype]
+             if spec.dtype in primitive_type_to_tf_type
+             else spec.dtype,
+             name=f"args_{i}",
+         )
+     )
+   return ret
+
+
+ def _apply_tfl_backdoor_flags(
+     converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict
+ ):
+   def _set_converter_flag(path: list):
+     if len(path) < 2:
+       raise ValueError("Expecting at least two values in the path.")
+
+     target_obj = converter
+     for idx in range(len(path) - 2):
+       target_obj = getattr(target_obj, path[idx])
+
+     setattr(target_obj, path[-2], path[-1])
+
+   def _iterate_dict_tree(flags_dict: dict, path: list):
+     for key, value in flags_dict.items():
+       path.append(key)
+       if isinstance(value, dict):
+         _iterate_dict_tree(value, path)
+       else:
+         path.append(value)
+         _set_converter_flag(path)
+         path.pop()
+       path.pop()
+
+   _iterate_dict_tree(tfl_converter_flags, [])
+
+
+ def _set_tfl_converter_quant_flags(
+     converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
+ ):
+   if quant_config is not None:
+     quantizer_mode = quant_config._quantizer_mode
+     if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
+       converter._experimental_qdq_conversion_mode = "DYNAMIC"
+     elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
+       converter._experimental_qdq_conversion_mode = "STATIC"
+     elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_DYNAMIC:
+       converter.optimizations = [tf.lite.Optimize.DEFAULT]
+     elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_FP16:
+       converter.optimizations = [tf.lite.Optimize.DEFAULT]
+       converter.target_spec.supported_types = [tf.float16]
+
+
+ def convert_stablehlo_to_tflite(
+     shlo_graph_module: stablehlo.StableHLOGraphModule,
+     signatures: list[Signature],
+     *,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     _tfl_converter_flags: dict = {},
+ ) -> None:
+   """Converts a StableHLOGraphModule to a tflite model.
+   Args:
+     shlo_graph_module - model to export and save
+     signatures: List of signatures from which names of the signatures is extracted.
+     quant_config: User-defined quantization method and scheme of the model.
+     _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
+   """
+
+   bundle = shlo_graph_module._bundle
+   tf_module = tf.Module()
+   bundle.state_dict = {
+       k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
+   }
+   bundle.additional_constants = [
+       tf.Variable(v, trainable=False) for v in bundle.additional_constants
+   ]
+   tf_signatures: list[list[tf.TensorSpec]] = list(
+       _make_tf_signature(func.meta) for func in bundle.stablehlo_funcs
+   )
+
+   tf_functions = _make_tf_function(shlo_graph_module, bundle)
+
+   tf_module.f = []
+   for tf_sig, func in zip(tf_signatures, tf_functions):
+     tf_module.f.append(
+         tf.function(
+             func,
+             input_signature=tf_sig,
+         )
+     )
+
+   tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
+   del bundle
+   gc.collect()
+
+   tf_concrete_funcs = [
+       func.get_concrete_function(*tf_sig)
+       for func, tf_sig in zip(tf_module.f, tf_signatures)
+   ]
+
+   # We need to temporarily save since TFLite's from_concrete_functions does not
+   # allow providing names for each of the concrete functions.
+   with tempfile.TemporaryDirectory() as temp_dir_path:
+     tf.saved_model.save(
+         tf_module,
+         temp_dir_path,
+         signatures={
+             sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
+         },
+     )
+     # Clean up intermediate memory early.
+     del tf_module
+     del tf_concrete_funcs
+     gc.collect()
+
+     converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
+     converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
+     converter._experimental_enable_composite_direct_lowering = True
+
+     _set_tfl_converter_quant_flags(converter, quant_config)
+     _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
+
+     tflite_model = converter.convert()
+
+   return tflite_model
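Worth noting: _apply_tfl_backdoor_flags walks the nested flags dict depth-first and turns each leaf into a setattr on the underlying tf.lite.TFLiteConverter, so {"target_spec": {"supported_types": [tf.float16]}} becomes converter.target_spec.supported_types = [tf.float16]. A hedged sketch of how such a dict reaches this code through the public API (the toy module and flag values are illustrative only, not a recommendation):

  import tensorflow as tf
  import torch
  import ai_edge_torch

  # Hypothetical toy module just to have something convertible.
  module = torch.nn.Linear(4, 2).eval()
  sample_args = (torch.randn(1, 4),)

  # Each leaf becomes a setattr on the TFLite converter built inside convert_stablehlo_to_tflite.
  edge_model = ai_edge_torch.convert(
      module,
      sample_args,
      _ai_edge_converter_flags={
          "optimizations": [tf.lite.Optimize.DEFAULT],
          "target_spec": {"supported_types": [tf.float16]},
      },
  )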
@@ -0,0 +1,171 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+
+ from ai_edge_torch import model
+ from ai_edge_torch.convert import conversion
+ from ai_edge_torch.convert import conversion_utils as cutils
+ from ai_edge_torch.quantize import quant_config as qcfg
+
+
+ class Converter:
+
+   def __init__(self):
+     self._signatures: list[cutils.Signature] = []
+
+   def signature(
+       self,
+       name: str,
+       module: torch.nn.Module,
+       sample_args: tuple[cutils.TracingArg],
+       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+   ) -> Converter:
+     """Alias to `add_signature`"""
+     return self.add_signature(name, module, sample_args, dynamic_shapes)
+
+   def add_signature(
+       self,
+       name: str,
+       module: torch.nn.Module,
+       sample_args: tuple[cutils.TracingArg],
+       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+   ) -> Converter:
+     """Allows adding a new named torch model along with sample args to the conversion.
+
+     Args:
+       name: The name of the signature included in the converted edge model.
+       module: The torch module to be converted.
+       sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
+         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
+
+     Raises:
+       ValueError: If a signature with the provided name already exists.
+     """
+
+     if name in [sig.name for sig in self._signatures]:
+       raise ValueError(f"A signature with the provided name ({name}) is already added.")
+
+     self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes))
+     return self
+
+   def convert(
+       self,
+       module: torch.nn.Module = None,
+       sample_args: tuple[cutils.TracingArg] = None,
+       *,
+       quant_config: Optional[qcfg.QuantConfig] = None,
+       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+       _ai_edge_converter_flags: dict = {},
+   ) -> model.TfLiteModel:
+     """Finalizes the conversion and produces an edge model.
+
+     This could be called with no arguments as follows:
+
+       edge_model = Converter().signature(name, module, args).convert()
+
+     Or it could be used to set the default signature for the converted edge model:
+
+       edge_model = Converter().convert(module, args)
+
+     Args:
+       name: The name of the signature included in the converted edge model.
+       module: The torch module to be converted.
+       sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+       quant_config: User-defined quantization method and scheme of the model.
+       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
+         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
+       _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
+         This gives access to an implementation detail of this function and so needs to be treated as such.
+         Please do not rely on this parameter except for local debugging as this can be removed in a future release.
+
+     Raises:
+       ValueError: If the arguments are not provided as expected. See the example in this function's comment.
+     """
+     if module is not None:
+       if sample_args is not None:  # both module and args provided
+         self.add_signature(
+             cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes
+         )
+       else:  # module is provided but not sample_args
+         raise ValueError("sample_args needs to be provided if a module is specified.")
+
+     return conversion.convert_signatures(
+         self._signatures,
+         quant_config=quant_config,
+         _tfl_converter_flags=_ai_edge_converter_flags,
+     )
+
+
+ def signature(
+     name: str,
+     module: torch.nn.Module,
+     sample_args: tuple[cutils.TracingArg],
+     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+ ) -> Converter:
+   """Initiates a Converter object with the provided signature.
+
+   Args:
+     name: The name of the signature included in the converted edge model.
+     module: The torch module to be converted.
+     sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+     dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
+       See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
+
+   Example:
+     converter = ai_edge_torch.signature(name, module, args)
+     edge_model = converter.convert()
+
+   """
+   return Converter().signature(name, module, sample_args, dynamic_shapes)
+
+
+ def convert(
+     module: torch.nn.Module = None,
+     sample_args: tuple[cutils.TracingArg] = None,
+     *,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+     _ai_edge_converter_flags: dict = {},
+ ) -> model.TfLiteModel:
+   """Allows converting a PyTorch model to an edge model with one default signature in one step.
+
+   Args:
+     module: The torch module to be converted.
+     sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+     quant_config: User-defined quantization method and scheme of the model.
+     dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
+       See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
+     _ai_edge_converter_flags: A nested dictionary allowing setting flags for the underlying converter.
+       This gives access to an implementation detail of this function and so needs to be treated as such.
+       Please do not rely on this parameter except for local debugging as this can be removed in a future release.
+
+   Example:
+     edge_model = ai_edge_torch.convert(module, args)
+
+   """
+
+   return Converter().convert(
+       module,
+       sample_args,
+       quant_config=quant_config,
+       dynamic_shapes=dynamic_shapes,
+       _ai_edge_converter_flags=_ai_edge_converter_flags,
+   )
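The converter.py docstrings above describe both the one-step convert() path and the chained multi-signature path. A hedged sketch of the latter, embedding two named signatures into one edge model (the encoder/decoder modules and signature names below are hypothetical):

  import torch
  import ai_edge_torch

  # Hypothetical modules; any pair of eval-mode torch.nn.Modules works.
  encoder = torch.nn.Linear(16, 8).eval()
  decoder = torch.nn.Linear(8, 16).eval()

  # signature() starts a Converter, add_signature() chains onto it, convert() finalizes.
  edge_model = (
      ai_edge_torch.signature("encode", encoder, (torch.randn(1, 16),))
      .add_signature("decode", decoder, (torch.randn(1, 8),))
      .convert()
  )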