ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff shows the content of this publicly released package version as it appears in its public registry. It is provided for informational purposes only.

This version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
ai_edge_torch/__init__.py
@@ -0,0 +1,31 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from .convert.converter import convert
+ from .convert.converter import signature
+ from .convert.to_channel_last_io import to_channel_last_io
+ from .model import Model
+
+
+ def load(path: str) -> Model:
+   """Imports an ai_edge_torch model from disk.
+
+   Args:
+     path: The path to the serialized ai_edge_torch model.
+
+   Returns:
+     An ai_edge_torch.model.Model object.
+   """
+   return Model.load(path)
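
For orientation (not part of the diff): this top-level module re-exports `convert`, `signature`, and `to_channel_last_io` from the converter, plus the `load` helper shown above. Below is a minimal usage sketch, assuming `convert` accepts an eval-mode `torch.nn.Module` with a tuple of sample inputs and that the returned model exposes an `export` method; both live in files listed above (converter.py, model.py) but are not shown in this section.

import torch

import ai_edge_torch


# Hypothetical toy module; any traceable torch.nn.Module should do.
class TinyModel(torch.nn.Module):

  def forward(self, x):
    return torch.nn.functional.relu(x) + 1.0


sample_inputs = (torch.randn(1, 8),)

# Convert in eval mode (conversion.py below warns about training-mode modules).
edge_model = ai_edge_torch.convert(TinyModel().eval(), sample_inputs)

# `export` is assumed from ai_edge_torch/model.py; `load` is defined above.
edge_model.export("/tmp/tiny_model.tflite")
reloaded = ai_edge_torch.load("/tmp/tiny_model.tflite")
print(reloaded(*sample_inputs))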
ai_edge_torch/convert/__init__.py
@@ -0,0 +1,14 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
ai_edge_torch/convert/conversion.py
@@ -0,0 +1,117 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import gc
+ import logging
+ import os
+ from typing import Optional
+
+ import torch
+ from torch.export import ExportedProgram
+ from torch_xla import stablehlo
+
+ from ai_edge_torch import model
+ from ai_edge_torch.convert import conversion_utils as cutils
+ from ai_edge_torch.convert.fx_passes import BuildAtenCompositePass
+ from ai_edge_torch.convert.fx_passes import BuildInterpolateCompositePass  # NOQA
+ from ai_edge_torch.convert.fx_passes import CanonicalizePass
+ from ai_edge_torch.convert.fx_passes import InjectMlirDebuginfoPass
+ from ai_edge_torch.convert.fx_passes import OptimizeLayoutTransposesPass
+ from ai_edge_torch.convert.fx_passes import run_passes
+ from ai_edge_torch.generative.fx_passes import run_generative_passes
+ from ai_edge_torch.quantize import quant_config as qcfg
+
+ os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
+ def _run_convert_passes(
+     exported_program: ExportedProgram,
+ ) -> ExportedProgram:
+   exported_program = run_generative_passes(exported_program)
+   return run_passes(
+       exported_program,
+       [
+           BuildInterpolateCompositePass(),
+           CanonicalizePass(),
+           OptimizeLayoutTransposesPass(),
+           CanonicalizePass(),
+           BuildAtenCompositePass(),
+           CanonicalizePass(),
+           InjectMlirDebuginfoPass(),
+           CanonicalizePass(),
+       ],
+   )
+
+
+ def _warn_training_modules(signatures: list[cutils.Signature]):
+   for sig in signatures:
+     if not sig.module.training:
+       continue
+
+     message = (
+         "Your model {sig_name}is converted in training mode. "
+         "Please set the module in evaluation mode with `module.eval()` for better on-device performance and compatibility."
+     )
+     if len(signatures) == 1 and sig.name == cutils.DEFAULT_SIGNATURE_NAME:
+       # User does not specify any signature names explicitly.
+       message = message.format(sig_name="")
+     else:
+       message = message.format(sig_name=f'"{sig.name}" ')
+
+     logging.warn(message)
+
+
+ def convert_signatures(
+     signatures: list[cutils.Signature],
+     *,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     _tfl_converter_flags: dict = {},
+ ) -> model.TfLiteModel:
+   """Converts a list of `Signature`s and embeds them into one `model.TfLiteModel`.
+   Args:
+     signatures: The list of 'Signature' objects containing PyTorch modules to be converted.
+     quant_config: User-defined quantization method and scheme of the model.
+     _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
+   """
+   _warn_training_modules(signatures)
+
+   exported_programs: torch.export.ExportedProgram = [
+       torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
+       for sig in signatures
+   ]
+
+   # Apply default fx passes
+   exported_programs = list(map(_run_convert_passes, exported_programs))
+   shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
+       cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
+       for exported, sig in zip(exported_programs, signatures)
+   ]
+
+   merged_shlo_graph_module: stablehlo.StableHLOGraphModule = (
+       cutils.merge_stablehlo_bundles(shlo_bundles, signatures, exported_programs)
+   )
+   del exported_programs
+   del shlo_bundles
+
+   gc.collect()
+
+   tflite_model = cutils.convert_stablehlo_to_tflite(
+       merged_shlo_graph_module,
+       signatures,
+       quant_config=quant_config,
+       _tfl_converter_flags=_tfl_converter_flags,
+   )
+
+   return model.TfLiteModel(tflite_model)
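
As an aside (not part of the diff), `convert_signatures` above is the internal entry point driven by the public converter API. A minimal sketch of calling it directly, using the `Signature` dataclass defined in conversion_utils.py below; the toy module here is purely illustrative:

import torch

from ai_edge_torch.convert import conversion
from ai_edge_torch.convert import conversion_utils as cutils

# A toy eval-mode module to convert (illustrative only).
module = torch.nn.Linear(4, 2).eval()

sig = cutils.Signature(
    name=cutils.DEFAULT_SIGNATURE_NAME,
    module=module,
    sample_args=(torch.randn(1, 4),),
    sample_kwargs={},
)

# Exports with torch.export, runs the fx passes above, lowers each program to
# StableHLO, merges the bundles, and returns a model.TfLiteModel.
edge_model = conversion.convert_signatures([sig])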
ai_edge_torch/convert/conversion_utils.py
@@ -0,0 +1,400 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import collections
+ import copy
+ from dataclasses import dataclass
+ import gc
+ import itertools
+ import logging
+ import tempfile
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import torch
+ import torch.utils._pytree as pytree
+ from torch_xla import stablehlo
+
+ from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
+ from ai_edge_torch.quantize import quant_config as qcfg
+
+ try:
+   import tensorflow as tf
+   from tensorflow.compiler.tf2xla.python import xla as tfxla
+
+   from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb  # isort:skip
+ except ImportError:
+   logging.error(
+       "This module needs tensorflow with xla support.\n"
+       "Please install tensorflow with `pip install tf-nightly`.\n"
+   )
+   raise
+
+ DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+
+ @dataclass
+ class Signature:
+   name: str
+   module: torch.nn.Module
+   sample_args: tuple[torch.Tensor]
+   sample_kwargs: dict[str, torch.Tensor]
+   dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
+
+   @property
+   def _normalized_sample_args_kwargs(self):
+     args, kwargs = self.sample_args, self.sample_kwargs
+     if args is not None:
+       if not isinstance(args, tuple):
+         # TODO(b/352584188): Check value types
+         raise ValueError("sample_args must be a tuple of torch tensors.")
+     if kwargs is not None:
+       if not isinstance(kwargs, dict) or not all(
+           isinstance(key, str) for key in kwargs.keys()
+       ):
+         # TODO(b/352584188): Check value types
+         raise ValueError("sample_kwargs must be a dict of string to tensor.")
+
+     args = args if args is not None else tuple()
+     kwargs = kwargs if kwargs is not None else {}
+     return args, kwargs
+
+   @property
+   def flat_arg_names(self) -> list[str]:
+     spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
+     args_spec, kwargs_spec = spec.children_specs
+
+     names = []
+     for i in range(args_spec.num_leaves):
+       names.append(f"args_{i}")
+
+     dict_context = (
+         kwargs_spec.context
+         if kwargs_spec.type is not collections.defaultdict
+         # ignore mismatch of `default_factory` for defaultdict
+         else kwargs_spec.context[1]
+     )
+
+     for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
+       if value_spec.num_leaves == 1:
+         names.append(name)
+       else:
+         # value_spec.num_leaves may be greater than 1 when the value is a (nested)
+         # tuple of tensors. We haven't decided how we should support flattenable
+         # tensor containers as inputs.
+         # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
+         for i in range(value_spec.num_leaves):
+           names.append(f"{name}_{i}")
+     return names
+
+   @property
+   def flat_args(self) -> tuple[torch.Tensor]:
+     return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
+
+
+ def exported_program_to_stablehlo_bundle(
+     exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
+ ) -> stablehlo.StableHLOModelBundle:
+   # Setting export_weights to False here so that pytorch/xla avoids copying the weights
+   # to a numpy array which would lead to memory bloat. This means that the state_dict
+   # in the returned bundle is going to be empty.
+   return stablehlo.exported_program_to_stablehlo(
+       exported_program,
+       stablehlo.StableHLOExportOptions(
+           override_tracing_arguments=sample_args, export_weights=False
+       ),
+   )._bundle
+
+
+ def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
+   if not torch_tensor.is_contiguous():
+     torch_tensor = torch_tensor.contiguous()
+
+   try:
+     dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
+     tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
+   except Exception:
+     logging.info("Can not use dlpack to convert torch tensors. Falling back to numpy.")
+     nparray = torch_tensor.cpu().detach().numpy()
+     tf_tensor = tf.convert_to_tensor(nparray)
+
+   return tf_tensor
+
+
+ def _get_states(
+     exported_programs: list[torch.export.ExportedProgram], signatures: list[Signature]
+ ):
+   for exported_program, signature in zip(exported_programs, signatures):
+     args, _ = exported_program.example_inputs
+     # Calling this to get **all** the state including model buffers.
+     _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
+     for tensor, input_spec in zip(
+         _flat_input_args, exported_program.graph_signature.input_specs
+     ):
+       # Only interested in Tensors that are part of the state (and not user input).
+       if (
+           not isinstance(tensor, torch.Tensor)
+           or input_spec.kind == torch.export.graph_signature.InputKind.USER_INPUT
+       ):
+         continue
+       yield signature, tensor, input_spec
+
+
+ def _tensor_unique_id(tensor: torch.Tensor):
+   return (
+       str(tensor.device),
+       tensor.shape,
+       tensor.stride(),
+       tensor.untyped_storage().data_ptr(),
+   )
+
+
+ def _gather_state_dict(
+     exported_programs: list[torch.export.ExportedProgram],
+     signatures: list[Signature],
+ ):
+   deduped_tensor_map = {}
+
+   for _, tensor, _ in _get_states(exported_programs, signatures):
+     unique_id = _tensor_unique_id(tensor)
+     deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
+
+   state_dict = {}
+   for signature, tensor, input_spec in _get_states(exported_programs, signatures):
+     unique_id = _tensor_unique_id(tensor)
+     state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[unique_id]
+
+   return state_dict
+
+
+ def merge_stablehlo_bundles(
+     bundles: list[stablehlo.StableHLOModelBundle],
+     signatures: list[Signature],
+     exported_programs: list[torch.export.ExportedProgram],
+ ) -> stablehlo.StableHLOGraphModule:
+   state_dict = _gather_state_dict(exported_programs, signatures)
+
+   new_bundle = stablehlo.StableHLOModelBundle(
+       state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
+   )
+
+   for bundle, signature in zip(bundles, signatures):
+     const_offset = len(new_bundle.additional_constants)
+     for func in bundle.stablehlo_funcs:
+       func.meta.name = signature.name + "_" + func.meta.name
+       for loc in func.meta.input_locations:
+         if loc.type_ == stablehlo.VariableType.CONSTANT:
+           loc.position += const_offset
+         elif loc.type_ == stablehlo.VariableType.PARAMETER:
+           loc.name = signature.name + "_" + loc.name
+       new_bundle.stablehlo_funcs.append(func)
+     new_bundle.additional_constants.extend(bundle.additional_constants)
+   return stablehlo.StableHLOGraphModule(new_bundle)
+
+
+ def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
+   shape = copy.copy(signature.shape)
+   for i in signature.dynamic_dims:
+     shape[i] = None
+   return shape
+
+
+ def _wrap_as_tf_func(
+     func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
+ ):
+   def inner(*args):
+     type_info = [sig.dtype for sig in func.meta.output_signature]
+     shape_info = [_get_shape_with_dynamic(sig) for sig in func.meta.output_signature]
+     call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
+     return tfxla.call_module(
+         tuple(call_args),
+         version=5,
+         Tout=type_info,
+         Sout=shape_info,
+         function_list=[],
+         module=func.bytecode,
+     )
+
+   return inner
+
+
+ def _make_tf_function(
+     shlo_graph_module: stablehlo.StableHLOGraphModule,
+     bundle: stablehlo.StableHLOModelBundle = None,
+ ):
+   bundle = shlo_graph_module._bundle if bundle is None else bundle
+   return [
+       _wrap_as_tf_func(func, bundle)
+       for func in shlo_graph_module._bundle.stablehlo_funcs
+   ]
+
+
+ def _make_tf_signature(
+     meta: stablehlo.StableHLOFunctionMeta,
+     signature: Signature,
+ ) -> list[tf.TensorSpec]:
+   input_names = signature.flat_arg_names
+   input_pos_to_spec = {
+       loc.position: spec
+       for loc, spec in itertools.chain(
+           zip(meta.input_locations, meta.input_signature), meta.unused_inputs
+       )
+       if loc.type_ == stablehlo.VariableType.INPUT_ARG
+   }
+   assert len(input_pos_to_spec) == len(input_names)
+
+   primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
+   ret: list[tf.TensorSpec] = []
+   for i, name in enumerate(input_names):
+     spec = input_pos_to_spec[i]
+     shape = _get_shape_with_dynamic(spec)
+     ret.append(
+         tf.TensorSpec(
+             shape=shape,
+             dtype=primitive_type_to_tf_type[spec.dtype]
+             if spec.dtype in primitive_type_to_tf_type
+             else spec.dtype,
+             name=name,
+         )
+     )
+   return ret
+
+
+ def _apply_tfl_backdoor_flags(
+     converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict
+ ):
+   def _set_converter_flag(path: list):
+     if len(path) < 2:
+       raise ValueError("Expecting at least two values in the path.")
+
+     target_obj = converter
+     for idx in range(len(path) - 2):
+       target_obj = getattr(target_obj, path[idx])
+
+     setattr(target_obj, path[-2], path[-1])
+
+   def _iterate_dict_tree(flags_dict: dict, path: list):
+     for key, value in flags_dict.items():
+       path.append(key)
+       if isinstance(value, dict):
+         _iterate_dict_tree(value, path)
+       else:
+         path.append(value)
+         _set_converter_flag(path)
+         path.pop()
+       path.pop()
+
+   _iterate_dict_tree(tfl_converter_flags, [])
+
+
+ def _set_tfl_converter_quant_flags(
+     converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
+ ):
+   if quant_config is not None:
+     quantizer_mode = quant_config._quantizer_mode
+     if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
+       converter._experimental_qdq_conversion_mode = "DYNAMIC"
+     elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
+       converter._experimental_qdq_conversion_mode = "STATIC"
+
+
+ def convert_stablehlo_to_tflite(
+     shlo_graph_module: stablehlo.StableHLOGraphModule,
+     signatures: list[Signature],
+     *,
+     quant_config: Optional[qcfg.QuantConfig] = None,
+     _tfl_converter_flags: dict = {},
+ ) -> None:
+   """Converts a StableHLOGraphModule to a tflite model.
+   Args:
+     shlo_graph_module - model to export and save
+     signatures: List of signatures from which names of the signatures is extracted.
+     quant_config: User-defined quantization method and scheme of the model.
+     _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
+   """
+
+   bundle = shlo_graph_module._bundle
+   tf_module = tf.Module()
+   bundle.state_dict = {
+       k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
+   }
+   bundle.additional_constants = [
+       tf.Variable(v, trainable=False) for v in bundle.additional_constants
+   ]
+   tf_signatures: list[list[tf.TensorSpec]] = list(
+       _make_tf_signature(func.meta, sig)
+       for func, sig in zip(bundle.stablehlo_funcs, signatures)
+   )
+
+   tf_functions = _make_tf_function(shlo_graph_module, bundle)
+
+   tf_module.f = []
+   for tf_sig, func in zip(tf_signatures, tf_functions):
+     tf_module.f.append(
+         tf.function(
+             func,
+             input_signature=tf_sig,
+         )
+     )
+
+   tf_module._variables = list(bundle.state_dict.values()) + bundle.additional_constants
+   del bundle
+   gc.collect()
+
+   tf_concrete_funcs = [
+       func.get_concrete_function(*tf_sig)
+       for func, tf_sig in zip(tf_module.f, tf_signatures)
+   ]
+
+   # We need to temporarily save since TFLite's from_concrete_functions does not
+   # allow providing names for each of the concrete functions.
+   with tempfile.TemporaryDirectory() as temp_dir_path:
+     tf.saved_model.save(
+         tf_module,
+         temp_dir_path,
+         signatures={
+             sig.name: tf_concrete_funcs[idx] for idx, sig in enumerate(signatures)
+         },
+     )
+     # Clean up intermediate memory early.
+     del tf_module
+     del tf_concrete_funcs
+     gc.collect()
+
+     converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
+     converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
+     converter._experimental_enable_composite_direct_lowering = True
+
+     _set_tfl_converter_quant_flags(converter, quant_config)
+     if (
+         quant_config is not None
+         and quant_config._quantizer_mode
+         == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+     ):
+       translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+           quant_config.generative_recipe
+       )
+
+     _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
+
+     tflite_model = converter.convert()
+
+     if (
+         quant_config is not None
+         and quant_config._quantizer_mode
+         == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+     ):
+       tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
+
+     return tflite_model
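
One detail worth calling out (not part of the diff): `_apply_tfl_backdoor_flags` walks the nested `_tfl_converter_flags` dict depth-first and `setattr`s each non-dict leaf at the end of its key path on the converter. Below is a standalone sketch of the same walk, with illustrative flag names only:

def apply_nested_flags(converter, flags: dict, prefix=()):
  """Assigns each non-dict leaf of `flags` via setattr along its key path."""
  for key, value in flags.items():
    if isinstance(value, dict):
      apply_nested_flags(converter, value, list(prefix) + [key])
    else:
      target = converter
      for attr in prefix:  # descend, e.g. converter.target_spec
        target = getattr(target, attr)
      setattr(target, key, value)

# e.g. {"target_spec": {"supported_ops": [...]}} has the same effect as
#   converter.target_spec.supported_ops = [...]
# which is how convert_signatures(..., _tfl_converter_flags=...) reaches
# arbitrary tf.lite.TFLiteConverter options.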