ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (104) hide show
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import logging
17
+
18
+ from ai_edge_torch._convert import signature as signature_module
19
+ import tensorflow as tf
20
+ import torch
21
+
22
+
23
def _torch_to_tf_variable(torch_tensor: torch.Tensor):
  """Wraps the data of a torch tensor in a non-trainable `tf.Variable`.

  The tensor is made contiguous if it is not already. A DLPack hand-off is
  tried first; if that raises, the data goes through a NumPy array instead.

  Args:
    torch_tensor: The tensor to convert.

  Returns:
    A `tf.Variable` created with `trainable=False`.
  """
  source = (
      torch_tensor
      if torch_tensor.is_contiguous()
      else torch_tensor.contiguous()
  )

  try:
    capsule = torch.utils.dlpack.to_dlpack(source)
    converted = tf.experimental.dlpack.from_dlpack(capsule)
  except Exception:
    # Best-effort path: any DLPack failure falls back to a NumPy round trip.
    logging.info(
        "Can not use dlpack to convert torch tensors. Falling back to numpy."
    )
    converted = tf.convert_to_tensor(source.cpu().detach().numpy())

  return tf.Variable(converted, trainable=False)
38
+
39
+
40
def _get_states(
    exported_programs: list[torch.export.ExportedProgram],
    signatures: list[signature_module.Signature],
):
  """Yields the state tensors of each exported program.

  Programs and signatures are walked pairwise. For every flattened graph
  input that is a `torch.Tensor` and whose input spec is not a user input,
  a `(signature, tensor, input_spec)` triple is produced.

  Args:
    exported_programs: Programs whose state should be enumerated.
    signatures: Signatures paired positionally with `exported_programs`.

  Yields:
    `(signature, tensor, input_spec)` tuples.
  """
  user_input_kind = torch.export.graph_signature.InputKind.USER_INPUT
  for program, sig in zip(exported_programs, signatures):
    positional_args, _ = program.example_inputs
    # Calling this to get **all** the state including model buffers.
    flat_inputs = program._graph_module_flat_inputs(positional_args, {})
    specs = program.graph_signature.input_specs
    for value, spec in zip(flat_inputs, specs):
      # Keep only tensors that belong to the state (i.e. not user input).
      if isinstance(value, torch.Tensor) and spec.kind != user_input_kind:
        yield sig, value, spec
59
+
60
+
61
def _tensor_unique_id(tensor: torch.Tensor):
  """Returns a hashable key identifying the exact data a tensor views.

  Two tensors get the same key only when they live on the same device, share
  the same underlying storage, start at the same offset inside it, and have
  the same shape, stride and dtype.

  `untyped_storage().data_ptr()` is the base address of the storage, not of
  the tensor's first element, so `storage_offset()` is part of the key.
  Without it, distinct slices of one buffer with equal shape and stride would
  be treated as the same tensor. The dtype is included so a reinterpreting
  view (`tensor.view(other_dtype)`) is not confused with its source.

  Args:
    tensor: The tensor to fingerprint.

  Returns:
    A tuple usable as a dictionary key.
  """
  return (
      str(tensor.device),
      tensor.shape,
      tensor.stride(),
      tensor.untyped_storage().data_ptr(),
      tensor.storage_offset(),
      tensor.dtype,
  )
68
+
69
+
70
def gather_state_dict(
    exported_programs: list[torch.export.ExportedProgram],
    signatures: list[signature_module.Signature],
):
  """Collects the state of all programs as deduplicated TF variables.

  Every state tensor reported by `_get_states` is fingerprinted with
  `_tensor_unique_id`. Each distinct fingerprint is converted to a
  `tf.Variable` exactly once, and all state entries with that fingerprint
  share the resulting variable.

  Args:
    exported_programs: Programs whose state should be gathered.
    signatures: Signatures paired positionally with `exported_programs`.

  Returns:
    A `(state_dict, deduped_vars)` tuple. `state_dict` maps
    `"<signature.name>_<input_spec.target>"` to its variable. `deduped_vars`
    lists each distinct variable once, in first-seen order.
  """
  deduped_tensor_map = {}
  state_dict = {}

  for signature, tensor, input_spec in _get_states(
      exported_programs, signatures
  ):
    unique_id = _tensor_unique_id(tensor)
    # Convert only on first sight: re-converting a shared tensor would
    # allocate a throwaway copy of its data for every extra occurrence.
    if unique_id not in deduped_tensor_map:
      deduped_tensor_map[unique_id] = _torch_to_tf_variable(tensor)
    state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
        unique_id
    ]

  return state_dict, list(deduped_tensor_map.values())
@@ -0,0 +1,211 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import dataclasses
17
+ import tempfile
18
+ from typing import Any, Dict, List, Optional, Tuple
19
+
20
+ from ai_edge_torch import odml_torch
21
+ from ai_edge_torch._convert import conversion_utils
22
+ from ai_edge_torch._convert import signature as signature_module
23
+ from ai_edge_torch.lowertools import common_utils
24
+ from ai_edge_torch.odml_torch import export
25
+ from ai_edge_torch.odml_torch import export_utils
26
+ from ai_edge_torch.quantize import quant_config as qcfg
27
+ import tensorflow as tf
28
+ import torch
29
+
30
+ from tensorflow.compiler.tf2xla.python import xla as tfxla
31
+
32
+ MlirBundle = odml_torch.export.MlirLowered
33
+
34
+
35
@dataclasses.dataclass
class MergedBundle:
  """Lowered MLIR modules grouped with the TF variables they share."""

  # Lowered modules; consumers in this file pair them with signatures by
  # position.
  bundles: list[odml_torch.export.MlirLowered]
  # Variables backing the state of the bundles, each distinct one once.
  deduped_tf_vars: list[tf.Variable]
41
+
42
+
43
def torch_dtype_to_tf(dtype):
  """Maps a torch dtype to the equivalent TensorFlow dtype.

  Args:
    dtype: A `torch.dtype`.

  Returns:
    The matching `tf.DType`, or `None` when the dtype has no entry in the
    table below.
  """
  return {
      torch.double: tf.float64,
      torch.float32: tf.float32,
      torch.half: tf.float16,
      torch.bfloat16: tf.bfloat16,
      torch.long: tf.int64,
      torch.int32: tf.int32,
      torch.int16: tf.int16,
      torch.int8: tf.int8,
      torch.uint8: tf.uint8,
      torch.bool: tf.bool,
  }.get(dtype)
53
+
54
+
55
def _get_shape_with_dynamic(signature: export.VariableSignature):
  """Returns the signature's shape as a list with dynamic dims set to None.

  A dimension counts as dynamic when `export_utils.is_torch_dynamic` says so.
  """
  dims = []
  for dim in signature.shape:
    dims.append(None if export_utils.is_torch_dynamic(dim) else dim)
  return dims
59
+
60
+
61
def _extract_call_args(
    bundle: export.MlirLowered,
    args: Tuple[Any],
    tf_state_dict: Dict[str, tf.Variable],
):
  """Builds the argument list for a lowered module, in signature order.

  User inputs are taken from `args` at the index stored in their input spec.
  Parameters are looked up in `tf_state_dict` by the spec's name.

  NOTE(review): inputs that are neither user inputs nor parameters contribute
  nothing to the result; confirm that is intended.

  Args:
    bundle: The lowered module whose `input_signature` drives the ordering.
    args: The positional user inputs.
    tf_state_dict: Variables keyed by parameter name.

  Returns:
    The list of values to pass to the module.
  """
  resolved = []
  for var_sig in bundle.input_signature:
    spec = var_sig.input_spec
    if spec.is_user_input:
      resolved.append(args[spec.i])
      continue
    if spec.is_parameter:
      resolved.append(tf_state_dict[spec.name])
  return resolved
74
+
75
+
76
def _wrap_as_tf_func(bundle, tf_state_dict):
  """Returns a callable that runs the lowered module via `call_module`.

  Args:
    bundle: The lowered module; its output signature supplies the output
      dtypes and shapes, and its bytecode is the module that is executed.
    tf_state_dict: Variables used to resolve the module's parameters.

  Returns:
    A function taking the user inputs positionally.
  """

  def inner(*args):
    out_dtypes = []
    out_shapes = []
    for out_sig in bundle.output_signature:
      out_dtypes.append(torch_dtype_to_tf(out_sig.dtype))
      out_shapes.append(_get_shape_with_dynamic(out_sig))
    module_inputs = _extract_call_args(bundle, args, tf_state_dict)
    return tfxla.call_module(
        tuple(module_inputs),
        version=5,
        Tout=out_dtypes,  # dtype information
        Sout=out_shapes,  # Shape information
        function_list=[],
        module=bundle.module_bytecode,
    )

  return inner
91
+
92
+
93
def _make_tf_signature(
    input_signature: list[export.VariableSignature],
    signature: signature_module.Signature,
) -> List[tf.TensorSpec]:
  """Builds the TF input signature for the user inputs of a lowered module.

  Only entries flagged as user inputs are kept. They are ordered by the index
  stored in their input spec and named after the matching flat argument name
  of `signature`.

  Args:
    input_signature: Variable signatures of all module inputs.
    signature: The conversion signature supplying argument names.

  Returns:
    One `tf.TensorSpec` per user input.
  """
  arg_names = signature.flat_arg_names
  user_inputs = [s for s in input_signature if s.input_spec.is_user_input]
  user_inputs.sort(key=lambda s: s.input_spec.i)

  return [
      tf.TensorSpec(
          shape=_get_shape_with_dynamic(s),
          dtype=torch_dtype_to_tf(s.dtype),
          name=arg_names[s.input_spec.i],
      )
      for s in user_inputs
  ]
114
+
115
+
116
def merged_bundle_to_tfl_model(
    merged_bundle: MergedBundle,
    signatures: list[signature_module.Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: Optional[dict] = None,
):
  """Converts a merged bundle of lowered modules to a TFLite model.

  Each bundle is wrapped in a `tf.function`, the functions are exported as a
  SavedModel under the names of `signatures`, and the SavedModel is run
  through the TFLite converter.

  Args:
    merged_bundle: The lowered modules and their deduplicated variables.
    signatures: Signatures paired positionally with the bundles; their names
      become the SavedModel signature keys.
    quant_config: Optional quantization settings. They are applied to the
      converter, and when the AI Edge quantizer mode is selected the converted
      model is additionally quantized with the translated recipe.
    _tfl_converter_flags: Optional flags forwarded to
      `conversion_utils.apply_tfl_converter_flags`. Defaults to no flags.

  Returns:
    The converted (and, if requested, quantized) TFLite model.
  """
  # A fresh dict per call avoids sharing one mutable default between calls.
  if _tfl_converter_flags is None:
    _tfl_converter_flags = {}

  use_ai_edge_quantizer = (
      quant_config is not None
      and quant_config._quantizer_mode
      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
  )

  tf_state_dict = merged_bundle.bundles[0].state_dict

  tf_signatures = [
      _make_tf_signature(bundle.input_signature, sig)
      for bundle, sig in zip(merged_bundle.bundles, signatures)
  ]
  tf_functions = [
      _wrap_as_tf_func(bundle, tf_state_dict)
      for bundle in merged_bundle.bundles
  ]

  tf_module = tf.Module()
  tf_module.f = []

  for tf_sig, func in zip(tf_signatures, tf_functions):
    tf_module.f.append(
        tf.function(
            func,
            input_signature=tf_sig,
        )
    )

  tf_module._variables = merged_bundle.deduped_tf_vars

  tf_concrete_funcs = [
      func.get_concrete_function(*tf_sig)
      for func, tf_sig in zip(tf_module.f, tf_signatures)
  ]

  # We need to temporarily save since TFLite's from_concrete_functions does not
  # allow providing names for each of the concrete functions.
  with tempfile.TemporaryDirectory() as temp_dir_path:
    tf.saved_model.save(
        tf_module,
        temp_dir_path,
        signatures={
            sig.name: tf_concrete_funcs[idx]
            for idx, sig in enumerate(signatures)
        },
    )

    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
    converter._experimental_enable_composite_direct_lowering = True

    # Honour quant_config instead of silently dropping it.
    conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
    if use_ai_edge_quantizer:
      # Imported lazily: only this quantizer mode needs the recipe glue.
      from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA

      translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
          quant_config.generative_recipe
      )

    conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)

    tflite_model = converter.convert()

    if use_ai_edge_quantizer:
      tflite_model = translate_recipe.quantize_model(
          tflite_model, translated_recipe
      )

  return tflite_model
172
+
173
+
174
def exported_program_to_mlir_text(
    exported_program: torch.export.ExportedProgram,
) -> str:
  """Lowers an ExportedProgram and returns its MLIR text with debug info."""
  lowered = odml_torch.export.exported_program_to_mlir(exported_program)
  return lowered.get_text(enable_debug_info=True)
181
+
182
+
183
def exported_program_to_mlir(
    exported_program: torch.export.ExportedProgram,
    sample_args: tuple[torch.Tensor],
) -> export.MlirLowered:
  """Lowers an ExportedProgram to a MlirLowered.

  Args:
    exported_program: The program to lower.
    sample_args: Accepted but not used by this implementation.

  Returns:
    The lowered module.
  """
  lowered = odml_torch.export.exported_program_to_mlir(exported_program)
  return lowered
189
+
190
+
191
def merge_mlir_bundles(
    bundles: list[export.MlirLowered],
    signatures: list[signature_module.Signature],
    exported_programs: list[torch.export.ExportedProgram],
) -> MergedBundle:
  """Groups lowered modules into one MergedBundle sharing a state dict.

  Every bundle receives the same gathered state dict, and the names of its
  parameter inputs are prefixed with `"<signature.name>_"`. Only the list of
  bundles is copied, so the bundle objects passed in are modified in place.

  Args:
    bundles: The lowered modules to merge.
    signatures: Signatures paired positionally with `bundles`.
    exported_programs: Programs from which the state is gathered.

  Returns:
    The merged bundle with its deduplicated variables.
  """
  shared_state, unique_vars = common_utils.gather_state_dict(
      exported_programs, signatures
  )
  merged = MergedBundle(bundles=bundles.copy(), deduped_tf_vars=unique_vars)

  for lowered, sig in zip(merged.bundles, signatures):
    lowered.state_dict = shared_state

    for var_sig in lowered.input_signature:
      spec = var_sig.input_spec
      if not spec.is_parameter:
        continue
      spec.name = sig.name + "_" + spec.name

  return merged
@@ -0,0 +1,273 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import copy
17
+ import dataclasses
18
+ from dataclasses import dataclass
19
+ import gc
20
+ import itertools
21
+ import logging
22
+ import tempfile
23
+ from typing import Any, Dict, Optional, Tuple, Union
24
+
25
+ from ai_edge_torch import model
26
+ from ai_edge_torch._convert import conversion_utils
27
+ from ai_edge_torch._convert import signature as signature_module
28
+ from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
29
+ from ai_edge_torch.lowertools import common_utils
30
+ from ai_edge_torch.quantize import quant_config as qcfg
31
+ import torch
32
+ from torch_xla import stablehlo
33
+
34
+ try:
35
+ import tensorflow as tf
36
+
37
+ from tensorflow.compiler.tf2xla.python import xla as tfxla
38
+
39
+ from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
40
+ except ImportError:
41
+ logging.error(
42
+ "This module needs tensorflow with xla support.\n"
43
+ "Please install tensorflow with `pip install tf-nightly`.\n"
44
+ )
45
+ raise
46
+
47
+ MlirBundle = stablehlo.StableHLOModelBundle
48
+
49
+
50
@dataclasses.dataclass
class MergedBundle:
  """A merged StableHLO model bundle and the TF variables backing its state."""

  # The single bundle holding the functions of all signatures.
  bundle: stablehlo.StableHLOModelBundle
  # Variables backing the bundle's state, each distinct one once.
  deduped_tf_vars: list[tf.Variable]
55
+
56
+
57
def exported_program_to_mlir(
    exported_program: torch.export.ExportedProgram,
    sample_args: tuple[torch.Tensor],
) -> stablehlo.StableHLOModelBundle:
  """Lowers an ExportedProgram to a StableHLO model bundle.

  Args:
    exported_program: The program to lower.
    sample_args: Passed to the exporter as `override_tracing_arguments`.

  Returns:
    The bundle of the exported StableHLO module.
  """
  # Setting export_weights to False here so that pytorch/xla avoids copying the
  # weights to a numpy array which would lead to memory bloat. This means that
  # the state_dict in the returned bundle is going to be empty.
  export_options = stablehlo.StableHLOExportOptions(
      override_tracing_arguments=sample_args, export_weights=False
  )
  shlo_module = stablehlo.exported_program_to_stablehlo(
      exported_program, export_options
  )
  return shlo_module._bundle
70
+
71
+
72
def merge_mlir_bundles(
    bundles: list[stablehlo.StableHLOModelBundle],
    signatures: list[signature_module.Signature],
    exported_programs: list[torch.export.ExportedProgram],
) -> MergedBundle:
  """Merges several StableHLO bundles into a single MergedBundle.

  The functions of every bundle are renamed with a `"<signature.name>_"`
  prefix and appended to one new bundle. Parameter locations get the same
  prefix. Constant locations are shifted by the number of constants already
  merged, so they keep pointing at their own constants in the combined list.
  The function objects of the incoming bundles are modified in place.

  Args:
    bundles: The bundles to merge.
    signatures: Signatures paired positionally with `bundles`.
    exported_programs: Programs from which the shared state is gathered.

  Returns:
    A `MergedBundle` holding the combined bundle and the deduplicated
    variables.
  """
  state_dict, deduped_tf_vars = common_utils.gather_state_dict(
      exported_programs, signatures
  )

  new_shlo_model_bundle = stablehlo.StableHLOModelBundle(
      state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
  )

  for bundle, signature in zip(bundles, signatures):
    # Constants of this bundle land after everything merged so far.
    const_offset = len(new_shlo_model_bundle.additional_constants)
    for func in bundle.stablehlo_funcs:
      func.meta.name = signature.name + "_" + func.meta.name
      for loc in func.meta.input_locations:
        if loc.type_ == stablehlo.VariableType.CONSTANT:
          loc.position += const_offset
        elif loc.type_ == stablehlo.VariableType.PARAMETER:
          loc.name = signature.name + "_" + loc.name
      new_shlo_model_bundle.stablehlo_funcs.append(func)
    new_shlo_model_bundle.additional_constants.extend(
        bundle.additional_constants
    )
  return MergedBundle(
      bundle=new_shlo_model_bundle, deduped_tf_vars=deduped_tf_vars
  )
101
+
102
+
103
def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
  """Returns a copy of the signature's shape with dynamic dims set to None.

  The positions to blank out are those listed in `signature.dynamic_dims`.
  The signature's own shape is left untouched.
  """
  dims = copy.copy(signature.shape)
  for dynamic_axis in signature.dynamic_dims:
    dims[dynamic_axis] = None
  return dims
108
+
109
+
110
def _wrap_as_tf_func(
    func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
):
  """Returns a callable that runs a StableHLO function via `call_module`.

  Args:
    func: The function to run; its metadata supplies the output dtypes and
      shapes, and its bytecode is the module that is executed.
    bundle: The bundle used to resolve the function's call parameters.

  Returns:
    A function taking the user inputs positionally.
  """

  def inner(*args):
    out_dtypes = []
    out_shapes = []
    for out_sig in func.meta.output_signature:
      out_dtypes.append(out_sig.dtype)
      out_shapes.append(_get_shape_with_dynamic(out_sig))
    module_inputs = stablehlo._extract_call_parameters(args, func.meta, bundle)
    return tfxla.call_module(
        tuple(module_inputs),
        version=5,
        Tout=out_dtypes,
        Sout=out_shapes,
        function_list=[],
        module=func.bytecode,
    )

  return inner
129
+
130
+
131
def _make_tf_function(
    bundle: stablehlo.StableHLOModelBundle = None,
):
  """Wraps every StableHLO function of a bundle as a TF-callable.

  Args:
    bundle: The bundle whose `stablehlo_funcs` are wrapped. Despite the
      `None` default, a bundle must be supplied: there is no fallback.

  Returns:
    One callable per function, in the order of `bundle.stablehlo_funcs`.
  """
  # The former `bundle = bundle if bundle is None else bundle` was a no-op
  # and has been dropped.
  return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]
136
+
137
+
138
def _make_tf_signature(
    meta: stablehlo.StableHLOFunctionMeta,
    signature: signature_module.Signature,
) -> list[tf.TensorSpec]:
  """Builds the TF input signature for the user inputs of a function.

  Input-argument specs are gathered from both the used inputs and
  `meta.unused_inputs`, indexed by position, and emitted in the order of the
  signature's flat argument names.

  Args:
    meta: Metadata of the StableHLO function.
    signature: The conversion signature supplying argument names.

  Returns:
    One `tf.TensorSpec` per flat argument name.
  """
  arg_names = signature.flat_arg_names

  spec_by_position = {}
  located_specs = itertools.chain(
      zip(meta.input_locations, meta.input_signature), meta.unused_inputs
  )
  for loc, spec in located_specs:
    if loc.type_ == stablehlo.VariableType.INPUT_ARG:
      spec_by_position[loc.position] = spec
  assert len(spec_by_position) == len(arg_names)

  # Python primitive type names are mapped to concrete TF dtype names.
  primitive_type_to_tf_type = {"int": "int32", "float": "float32"}

  def _to_tensor_spec(position, name):
    spec = spec_by_position[position]
    return tf.TensorSpec(
        shape=_get_shape_with_dynamic(spec),
        dtype=primitive_type_to_tf_type.get(spec.dtype, spec.dtype),
        name=name,
    )

  return [_to_tensor_spec(i, name) for i, name in enumerate(arg_names)]
167
+
168
+
169
def exported_program_to_mlir_text(
    exported_program: torch.export.ExportedProgram,
) -> str:
  """Exports an ExportedProgram to StableHLO and returns its text form."""
  shlo_module = stablehlo.exported_program_to_stablehlo(exported_program)
  return shlo_module.get_stablehlo_text()
176
+
177
+
178
def merged_bundle_to_tfl_model(
    merged_bundle: MergedBundle,
    signatures: list[signature_module.Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: Optional[dict] = None,
):
  """Converts a merged StableHLO bundle to a TFLite model.

  Args:
    merged_bundle: The merged bundle and its deduplicated variables.
    signatures: List of signatures from which names of the signatures is
      extracted.
    quant_config: User-defined quantization method and scheme of the model.
    _tfl_converter_flags: A nested dictionary allowing setting flags for the
      underlying tflite converter. Defaults to no flags.

  Returns:
    The converted TFLite model. When the AI Edge quantizer mode is selected,
    this is the model after `translate_recipe.quantize_model` was applied.
  """
  # A fresh dict per call avoids sharing one mutable default between calls.
  if _tfl_converter_flags is None:
    _tfl_converter_flags = {}

  use_ai_edge_quantizer = (
      quant_config is not None
      and quant_config._quantizer_mode
      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
  )

  tf_module = tf.Module()

  shlo_bundle = merged_bundle.bundle

  shlo_bundle.additional_constants = [
      tf.Variable(v, trainable=False) for v in shlo_bundle.additional_constants
  ]
  tf_signatures: list[list[tf.TensorSpec]] = list(
      _make_tf_signature(func.meta, sig)
      for func, sig in zip(shlo_bundle.stablehlo_funcs, signatures)
  )

  tf_functions = _make_tf_function(shlo_bundle)

  tf_module.f = []
  for tf_sig, func in zip(tf_signatures, tf_functions):
    tf_module.f.append(
        tf.function(
            func,
            input_signature=tf_sig,
        )
    )

  tf_module._variables = (
      merged_bundle.deduped_tf_vars + shlo_bundle.additional_constants
  )
  del shlo_bundle
  gc.collect()

  tf_concrete_funcs = [
      func.get_concrete_function(*tf_sig)
      for func, tf_sig in zip(tf_module.f, tf_signatures)
  ]

  # We need to temporarily save since TFLite's from_concrete_functions does not
  # allow providing names for each of the concrete functions.
  with tempfile.TemporaryDirectory() as temp_dir_path:
    tf.saved_model.save(
        tf_module,
        temp_dir_path,
        signatures={
            sig.name: tf_concrete_funcs[idx]
            for idx, sig in enumerate(signatures)
        },
    )
    # Clean up intermediate memory early.
    del tf_module
    del tf_concrete_funcs
    gc.collect()

    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
    converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
    converter._experimental_enable_composite_direct_lowering = True

    conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
    if use_ai_edge_quantizer:
      translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
          quant_config.generative_recipe
      )

    conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)

    tflite_model = converter.convert()

    if use_ai_edge_quantizer:
      tflite_model = translate_recipe.quantize_model(
          tflite_model, translated_recipe
      )

  return tflite_model
ai_edge_torch/model.py CHANGED
@@ -15,17 +15,18 @@
15
15
 
16
16
  """Represents an ai_edge_torch model.
17
17
 
18
- PyTorch models can be converted to this representation through `ai_edge_torch.convert`.
18
+ PyTorch models can be converted to this representation through
19
+ `ai_edge_torch.convert`.
19
20
  """
20
21
  from __future__ import annotations
21
22
 
22
23
  import abc
23
24
 
24
- from ai_edge_torch.convert import conversion_utils as cutils
25
- import numpy as np
26
25
  import numpy.typing as npt
27
26
  import tensorflow as tf
28
27
 
28
+ DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
29
+
29
30
 
30
31
  class Model(abc.ABC):
31
32
  """Represents an edge model."""
@@ -34,7 +35,7 @@ class Model(abc.ABC):
34
35
  def __call__(
35
36
  self,
36
37
  *args: npt.ArrayLike,
37
- signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
38
+ signature_name: str = DEFAULT_SIGNATURE_NAME,
38
39
  **kwargs,
39
40
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
40
41
  raise NotImplementedError()
@@ -66,18 +67,22 @@ class TfLiteModel(Model):
66
67
  def __call__(
67
68
  self,
68
69
  *args: npt.ArrayLike,
69
- signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
70
+ signature_name: str = DEFAULT_SIGNATURE_NAME,
70
71
  **kwargs,
71
72
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
72
73
  """Runs inference on the edge model using the provided arguments.
73
74
 
74
75
  Args:
75
76
  *args: The arguments to be passed to the model for inference.
76
- **kwargs: The arguments with specific names to be passed to the model for inference.
77
- signature_name: The name of the signature to be used for inference.
78
- The default signature is used if not provided.
77
+ **kwargs: The arguments with specific names to be passed to the model for
78
+ inference.
79
+ signature_name: The name of the signature to be used for inference. The
80
+ default signature is used if not provided.
79
81
  """
80
- interpreter = tf.lite.Interpreter(model_content=self._tflite_model)
82
+ interpreter = tf.lite.Interpreter(
83
+ model_content=self._tflite_model,
84
+ experimental_default_delegate_latest_features=True,
85
+ )
81
86
  interpreter.allocate_tensors()
82
87
 
83
88
  signature_list = interpreter.get_signature_list()
@@ -188,15 +188,18 @@ def _get_supported_config_and_operators() -> List[OperatorConfig]:
188
188
 
189
189
  def _get_module_name_filter(module_name: str):
190
190
  """Get the module_name_filter function for a given module name, the filter accepts
191
+
191
192
  a node and checks if the node comes from a module that has certain module name
192
193
 
193
194
  For example:
194
- node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
195
+ node: linear_op = call_function[...](...) # comes from a module with name
196
+ blocks.sub.linear1
195
197
 
196
198
 
197
199
  >> module_name_filter = _get_module_name_filter("blocks.sub")
198
200
  >> print(module_name_filter(node))
199
- True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
201
+ True # the node is from "blocks.sub" based on the fully qualified name
202
+ "blocks.sub.linear1"
200
203
  """
201
204
 
202
205
  def module_name_filter(n: Node) -> bool:
@@ -216,15 +219,19 @@ def _get_module_name_filter(module_name: str):
216
219
 
217
220
  def _get_module_type_filter(tp: Callable):
218
221
  """Get the module_type_filter function for a given module type, the filter accepts
222
+
219
223
  a node and checks if the node comes from a module that has certain module type
220
224
 
221
225
  For example:
222
- node: linear_op = call_function[...](...) # comes from a module with type Block -> Sub -> Linear
226
+ node: linear_op = call_function[...](...) # comes from a module with type
227
+ Block -> Sub -> Linear
223
228
 
224
229
 
225
- >> module_type_filter = _get_module_type_filter(Sub) # submodule with type `Sub`, under the `Block` submodule
230
+ >> module_type_filter = _get_module_type_filter(Sub) # submodule with type
231
+ `Sub`, under the `Block` submodule
226
232
  >> print(module_type_filter(node))
227
- True # the node is from the submodule `Sub` (same for `Block` and `Linear` as well)
233
+ True # the node is from the submodule `Sub` (same for `Block` and `Linear` as
234
+ well)
228
235
  """
229
236
 
230
237
  def module_type_filter(n: Node) -> bool:
@@ -338,8 +345,11 @@ class PT2EQuantizer(Quantizer):
338
345
  self, module_type: Callable, quantization_config: QuantizationConfig
339
346
  ):
340
347
  """Set quantization_config for a submodule with type: `module_type`, for example:
341
- quantizer.set_module_name(Sub) or quantizer.set_module_name(nn.Linear), it will quantize all supported operator/operator
342
- patterns in the submodule with this module type with the given `quantization_config`
348
+
349
+ quantizer.set_module_type(Sub) or quantizer.set_module_type(nn.Linear), it
350
+ will quantize all supported operator/operator
351
+ patterns in the submodule with this module type with the given
352
+ `quantization_config`
343
353
  """
344
354
  self.module_type_config[module_type] = quantization_config
345
355
  return self
@@ -348,8 +358,11 @@ class PT2EQuantizer(Quantizer):
348
358
  self, module_name: str, quantization_config: Optional[QuantizationConfig]
349
359
  ):
350
360
  """Set quantization_config for a submodule with name: `module_name`, for example:
351
- quantizer.set_module_name("blocks.sub"), it will quantize all supported operator/operator
352
- patterns in the submodule with this module name with the given `quantization_config`
361
+
362
+ quantizer.set_module_name("blocks.sub"), it will quantize all supported
363
+ operator/operator
364
+ patterns in the submodule with this module name with the given
365
+ `quantization_config`
353
366
  """
354
367
  assert (
355
368
  quantization_config is not None