ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240807__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (103)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +201 -0
  88. ai_edge_torch/{convert/conversion_utils.py → lowertools/torch_xla_utils.py} +35 -214
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240807.dist-info/RECORD +141 -0
  97. ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
  98. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  99. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  101. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/LICENSE +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/WHEEL +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240807.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,89 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import logging
17
+
18
+ from ai_edge_torch._convert import signature as signature_module
19
+ import tensorflow as tf
20
+ import torch
21
+
22
+
23
def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
  """Converts a torch tensor into a TensorFlow tensor.

  A DLPack hand-off is attempted first; if anything in that path raises, the
  data is copied through a NumPy array instead.

  Args:
    torch_tensor: The tensor to convert. A non-contiguous tensor is made
      contiguous before conversion.

  Returns:
    The equivalent `tf.Tensor`.
  """
  source = (
      torch_tensor
      if torch_tensor.is_contiguous()
      else torch_tensor.contiguous()
  )

  try:
    converted = tf.experimental.dlpack.from_dlpack(
        torch.utils.dlpack.to_dlpack(source)
    )
  except Exception:
    # Best effort: any DLPack failure falls back to a host-side NumPy copy.
    logging.info(
        "Can not use dlpack to convert torch tensors. Falling back to numpy."
    )
    converted = tf.convert_to_tensor(source.cpu().detach().numpy())

  return converted
38
+
39
+
40
def _get_states(
    exported_programs: list[torch.export.ExportedProgram],
    signatures: list[signature_module.Signature],
):
  """Yields the non-user-input tensor inputs of each exported program.

  Args:
    exported_programs: Programs whose lifted state is collected.
    signatures: Signatures paired positionally with `exported_programs`.

  Yields:
    `(signature, tensor, input_spec)` for every flat graph-module input that
    is a `torch.Tensor` and whose input kind is not `USER_INPUT`.
  """
  user_input_kind = torch.export.graph_signature.InputKind.USER_INPUT
  for program, sig in zip(exported_programs, signatures):
    example_args, _ = program.example_inputs
    # The flattened graph-module inputs contain **all** the state, including
    # model buffers, in addition to the user inputs.
    flat_inputs = program._graph_module_flat_inputs(example_args, {})
    specs = program.graph_signature.input_specs
    for value, spec in zip(flat_inputs, specs):
      if isinstance(value, torch.Tensor) and spec.kind != user_input_kind:
        yield sig, value, spec
59
+
60
+
61
+ def _tensor_unique_id(tensor: torch.Tensor):
62
+ return (
63
+ str(tensor.device),
64
+ tensor.shape,
65
+ tensor.stride(),
66
+ tensor.untyped_storage().data_ptr(),
67
+ )
68
+
69
+
70
def gather_state_dict(
    exported_programs: list[torch.export.ExportedProgram],
    signatures: list[signature_module.Signature],
):
  """Collects the lifted state of all programs into one dict of TF tensors.

  Tensors with the same `_tensor_unique_id` key are converted to TensorFlow
  only once. The resulting TF tensor is shared by every entry that refers to
  that key.

  Args:
    exported_programs: Programs whose non-user-input tensors are collected.
    signatures: Signatures paired positionally with `exported_programs`. Each
      signature's name prefixes the keys of its program's state.

  Returns:
    A dict mapping `"<signature name>_<input target>"` to a TF tensor.
  """
  converted_by_id = {}
  state_dict = {}
  for signature, tensor, input_spec in _get_states(
      exported_programs, signatures
  ):
    unique_id = _tensor_unique_id(tensor)
    # Convert each distinct tensor once. Aliases reuse the converted tensor
    # instead of paying for another (possibly copying) conversion.
    if unique_id not in converted_by_id:
      converted_by_id[unique_id] = _torch_to_tf_tensor(tensor)
    state_dict[signature.name + "_" + input_spec.target] = converted_by_id[
        unique_id
    ]
  return state_dict
@@ -0,0 +1,201 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import tempfile
17
+ from typing import Any, Dict, List, Optional, Tuple
18
+
19
+ from ai_edge_torch import odml_torch
20
+ from ai_edge_torch._convert import conversion_utils
21
+ from ai_edge_torch._convert import signature as signature_module
22
+ from ai_edge_torch.lowertools import common_utils
23
+ from ai_edge_torch.odml_torch import export
24
+ from ai_edge_torch.odml_torch import export_utils
25
+ from ai_edge_torch.quantize import quant_config as qcfg
26
+ import tensorflow as tf
27
+ import torch
28
+
29
+ from tensorflow.compiler.tf2xla.python import xla as tfxla
30
+
31
# Aliases for the odml_torch lowering artifacts handled in this module: one
# lowered program, and the per-signature list returned by `merge_mlir_bundles`.
MlirBundle = odml_torch.export.MlirLowered
MergedBundle = list[MlirBundle]
33
+
34
+
35
# Mapping from torch dtypes to the TensorFlow dtypes used for the TF input
# signatures and the `call_module` output types in this module.
_TORCH_TO_TF_DTYPE = {
    torch.double: tf.float64,
    torch.float32: tf.float32,
    torch.half: tf.float16,
    torch.bfloat16: tf.bfloat16,
    torch.long: tf.int64,
    torch.int32: tf.int32,
    torch.int16: tf.int16,
    torch.int8: tf.int8,
    torch.uint8: tf.uint8,
    torch.bool: tf.bool,
}


def torch_dtype_to_tf(dtype):
  """Maps a torch dtype to the corresponding TensorFlow dtype.

  Args:
    dtype: A `torch.dtype`.

  Returns:
    The matching `tf.DType`, or None when the dtype has no mapping here.
  """
  return _TORCH_TO_TF_DTYPE.get(dtype)
45
+
46
+
47
def _get_shape_with_dynamic(signature: export.VariableSignature):
  """Returns `signature.shape` as a list with dynamic dims replaced by None.

  None is how a TensorFlow shape marks an unknown dimension.
  """

  def _as_tf_dim(dim):
    return None if export_utils.is_torch_dynamic(dim) else dim

  return list(map(_as_tf_dim, signature.shape))
51
+
52
+
53
def _extract_call_args(
    bundle: export.MlirLowered,
    args: Tuple[Any],
    tf_state_dict: Dict[str, tf.Variable],
):
  """Builds the positional argument list for calling a lowered module.

  Walks `bundle.input_signature` in order. User inputs are taken from `args`
  by their input index; parameters are looked up in `tf_state_dict` by name.

  Args:
    bundle: The lowered program whose input signature drives the ordering.
    args: The user-provided inputs.
    tf_state_dict: Mapping from parameter name to TF variable.

  Returns:
    A list with one entry per user-input or parameter signature, in
    `bundle.input_signature` order.
  """
  call_args = []
  for var_sig in bundle.input_signature:
    spec = var_sig.input_spec
    if spec.is_user_input:
      call_args.append(args[spec.i])
    elif spec.is_parameter:
      call_args.append(tf_state_dict[spec.name])
    # NOTE(review): signatures that are neither user inputs nor parameters
    # are skipped here, which would misalign the positional arguments that
    # follow. Confirm that no other input kinds can reach this point.
  return call_args
66
+
67
+
68
def _wrap_as_tf_func(bundle, tf_state_dict):
  """Wraps a lowered bundle as a Python callable suitable for `tf.function`.

  Args:
    bundle: The lowered program whose `module_bytecode` is invoked.
    tf_state_dict: Mapping from parameter name to TF variable. It supplies the
      parameter arguments of the module call.

  Returns:
    A function that takes the user inputs positionally and returns the result
    of `tfxla.call_module` on the bundle's module.
  """

  def inner(*args):
    output_sigs = bundle.output_signature
    out_dtypes = [torch_dtype_to_tf(s.dtype) for s in output_sigs]
    out_shapes = [_get_shape_with_dynamic(s) for s in output_sigs]
    module_args = tuple(_extract_call_args(bundle, args, tf_state_dict))
    return tfxla.call_module(
        module_args,
        version=5,
        Tout=out_dtypes,  # dtype information
        Sout=out_shapes,  # Shape information
        function_list=[],
        module=bundle.module_bytecode,
    )

  return inner
83
+
84
+
85
def _make_tf_signature(
    input_signature: list[export.VariableSignature],
    signature: signature_module.Signature,
) -> List[tf.TensorSpec]:
  """Builds the `tf.function` input signature for one lowered program.

  Only user inputs are included, ordered by their input index. Each spec is
  named after the matching entry of `signature.flat_arg_names`.

  Args:
    input_signature: The lowered program's input variable signatures.
    signature: The conversion signature that provides the argument names.

  Returns:
    One `tf.TensorSpec` per user input.
  """
  arg_names = signature.flat_arg_names
  user_inputs = [s for s in input_signature if s.input_spec.is_user_input]
  user_inputs.sort(key=lambda s: s.input_spec.i)
  return [
      tf.TensorSpec(
          shape=_get_shape_with_dynamic(s),
          dtype=torch_dtype_to_tf(s.dtype),
          name=arg_names[s.input_spec.i],
      )
      for s in user_inputs
  ]
106
+
107
+
108
def merged_bundle_to_tfl_model(
    merged_bundle: MergedBundle,
    signatures: list[signature_module.Signature],
    *,
    quant_config: Optional[qcfg.QuantConfig] = None,
    _tfl_converter_flags: Optional[dict] = None,
):
  """Converts merged odml_torch lowerings into a TFLite model.

  Each bundle is wrapped in a `tf.function` that calls its lowered module.
  The functions are saved as a SavedModel with one named signature per entry
  of `signatures`, and that SavedModel is converted with the TFLite converter.

  Args:
    merged_bundle: Lowered programs, as returned by `merge_mlir_bundles`.
      Only the first entry's `state_dict` is read; it is used for every
      function.
    signatures: Signatures paired positionally with `merged_bundle`. Their
      names become the TFLite signature names.
    quant_config: User-defined quantization method and scheme of the model.
    _tfl_converter_flags: A nested dictionary allowing setting flags for the
      underlying tflite converter.

  Returns:
    The result of `TFLiteConverter.convert()`.

  Raises:
    ValueError: If `merged_bundle` is empty, or if its length differs from
      the length of `signatures`.
  """
  if _tfl_converter_flags is None:
    _tfl_converter_flags = {}
  if not merged_bundle:
    raise ValueError("merged_bundle must contain at least one bundle.")
  if len(merged_bundle) != len(signatures):
    raise ValueError(
        f"Got {len(merged_bundle)} bundles but {len(signatures)} signatures."
    )

  # All bundles of a merged bundle share one state dict, so the variables are
  # created once and passed to every wrapped function.
  tf_state_dict = {
      k: tf.Variable(v, trainable=False)
      for k, v in merged_bundle[0].state_dict.items()
  }

  tf_signatures = [
      _make_tf_signature(bundle.input_signature, sig)
      for bundle, sig in zip(merged_bundle, signatures)
  ]
  tf_functions = [
      _wrap_as_tf_func(bundle, tf_state_dict) for bundle in merged_bundle
  ]

  tf_module = tf.Module()
  tf_module.f = []
  for tf_sig, func in zip(tf_signatures, tf_functions):
    tf_module.f.append(tf.function(func, input_signature=tf_sig))

  # Keep a reference to the variables on the module (presumably so they are
  # tracked and serialized with it).
  tf_module._variables = list(tf_state_dict.values())

  tf_concrete_funcs = [
      func.get_concrete_function(*tf_sig)
      for func, tf_sig in zip(tf_module.f, tf_signatures)
  ]

  # We need to temporarily save since TFLite's from_concrete_functions does not
  # allow providing names for each of the concrete functions.
  with tempfile.TemporaryDirectory() as temp_dir_path:
    tf.saved_model.save(
        tf_module,
        temp_dir_path,
        signatures={
            sig.name: concrete_func
            for sig, concrete_func in zip(signatures, tf_concrete_funcs)
        },
    )

    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
    converter._experimental_enable_composite_direct_lowering = True

    # Apply the quantization-related converter flags derived from
    # `quant_config`; the torch_xla lowering path does the same.
    conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
    conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)

    # Convert while the SavedModel directory still exists.
    tflite_model = converter.convert()

  return tflite_model
166
+
167
+
168
def exported_program_to_mlir_text(
    exported_program: torch.export.ExportedProgram,
) -> str:
  """Lowers an ExportedProgram with odml_torch and returns its MLIR text.

  The emitted text includes debug info.
  """
  lowered = odml_torch.export.exported_program_to_mlir(exported_program)
  return lowered.get_text(enable_debug_info=True)
175
+
176
+
177
def exported_program_to_mlir(
    exported_program: torch.export.ExportedProgram,
    sample_args: tuple[torch.Tensor],
) -> export.MlirLowered:
  """Lowers an ExportedProgram to an odml_torch `MlirLowered`.

  Args:
    exported_program: The program to lower.
    sample_args: Not used by this implementation. Presumably it is kept so the
      signature matches the torch_xla `exported_program_to_mlir`.

  Returns:
    The lowered program.
  """
  del sample_args  # Unused; see docstring.
  lowered = odml_torch.export.exported_program_to_mlir(exported_program)
  return lowered
183
+
184
+
185
def merge_mlir_bundles(
    bundles: list[export.MlirLowered],
    signatures: list[signature_module.Signature],
    exported_programs: list[torch.export.ExportedProgram],
) -> MergedBundle:
  """Makes per-signature lowerings share one deduplicated state dict.

  Every bundle is given the state dict gathered from all `exported_programs`.
  Each parameter input is renamed to `<signature name>_<parameter name>`, the
  same prefixing scheme that `common_utils.gather_state_dict` uses for its
  keys (assuming the parameter name equals the exported input target).

  Note: the returned list is a new list, but the bundles in it are the same
  objects as in `bundles` and are modified in place.

  Args:
    bundles: One lowered program per signature.
    signatures: Signatures paired positionally with `bundles`.
    exported_programs: The exported programs the bundles were lowered from.

  Returns:
    The list of updated bundles.
  """
  shared_state = common_utils.gather_state_dict(exported_programs, signatures)

  merged = bundles.copy()
  for lowered, sig in zip(merged, signatures):
    lowered.state_dict = shared_state
    param_specs = [
        var_sig.input_spec
        for var_sig in lowered.input_signature
        if var_sig.input_spec.is_parameter
    ]
    for spec in param_specs:
      spec.name = sig.name + "_" + spec.name
  return merged
@@ -13,19 +13,21 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import collections
17
16
  import copy
18
17
  from dataclasses import dataclass
19
18
  import gc
20
19
  import itertools
21
20
  import logging
22
21
  import tempfile
23
- from typing import Any, Dict, List, Optional, Tuple, Union
22
+ from typing import Any, Dict, Optional, Tuple, Union
24
23
 
24
+ from ai_edge_torch import model
25
+ from ai_edge_torch._convert import conversion_utils
26
+ from ai_edge_torch._convert import signature as signature_module
25
27
  from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
28
+ from ai_edge_torch.lowertools import common_utils
26
29
  from ai_edge_torch.quantize import quant_config as qcfg
27
30
  import torch
28
- import torch.utils._pytree as pytree
29
31
  from torch_xla import stablehlo
30
32
 
31
33
  try:
@@ -41,92 +43,11 @@ except ImportError:
41
43
  )
42
44
  raise
43
45
 
44
- DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
45
-
46
-
47
- @dataclass
48
- class Signature:
49
- name: str
50
- module: torch.nn.Module
51
- sample_args: tuple[torch.Tensor]
52
- sample_kwargs: dict[str, torch.Tensor]
53
- dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
54
-
55
- @property
56
- def _normalized_sample_args_kwargs(self):
57
- args, kwargs = self.sample_args, self.sample_kwargs
58
- if args is not None:
59
- if not isinstance(args, tuple):
60
- # TODO(b/352584188): Check value types
61
- raise ValueError("sample_args must be a tuple of torch tensors.")
62
- if kwargs is not None:
63
- if not isinstance(kwargs, dict) or not all(
64
- isinstance(key, str) for key in kwargs.keys()
65
- ):
66
- # TODO(b/352584188): Check value types
67
- raise ValueError("sample_kwargs must be a dict of string to tensor.")
68
-
69
- args = args if args is not None else tuple()
70
- kwargs = kwargs if kwargs is not None else {}
71
- return args, kwargs
72
-
73
- @property
74
- def flat_arg_names(self) -> list[str]:
75
- spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
76
- args_spec, kwargs_spec = spec.children_specs
77
-
78
- names = []
79
- for i in range(args_spec.num_leaves):
80
- names.append(f"args_{i}")
81
-
82
- kwargs_names = self._flat_kwarg_names(
83
- kwargs_spec.children_specs, kwargs_spec.context
84
- )
85
- names.extend(kwargs_names)
86
- return names
87
-
88
- def _flat_kwarg_names(self, specs, context) -> List[str]:
89
- flat_names = []
90
- if context is None:
91
- for i, spec in enumerate(specs):
92
- if spec.children_specs:
93
- flat_names.extend([
94
- f"{i}_{name}"
95
- for name in self._flat_kwarg_names(
96
- spec.children_specs, spec.context
97
- )
98
- ])
99
- else:
100
- flat_names.append(f"{i}")
101
- else:
102
- flat_ctx = self._flatten_list(context)
103
- for prefix, spec in zip(flat_ctx, specs):
104
- leaf_flat_names = self._flat_kwarg_names(
105
- spec.children_specs, spec.context
106
- )
107
- if leaf_flat_names:
108
- flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
109
- else:
110
- flat_names.append(prefix)
111
-
112
- return flat_names
113
-
114
- def _flatten_list(self, l: List) -> List:
115
- flattened = []
116
- for item in l:
117
- if isinstance(item, list):
118
- flattened.extend(self._flatten_list(item))
119
- else:
120
- flattened.append(item)
121
- return flattened
122
-
123
- @property
124
- def flat_args(self) -> tuple[Any]:
125
- args, kwargs = self._normalized_sample_args_kwargs
126
- return tuple([*args, *kwargs.values()])
46
+ MlirBundle = stablehlo.StableHLOModelBundle
47
+ MergedBundle = stablehlo.StableHLOModelBundle
127
48
 
128
49
 
129
- def exported_program_to_stablehlo_bundle(
50
+ def exported_program_to_mlir(
130
51
  exported_program: torch.export.ExportedProgram,
131
52
  sample_args: tuple[torch.Tensor],
132
53
  ) -> stablehlo.StableHLOModelBundle:
@@ -141,81 +62,12 @@ def exported_program_to_stablehlo_bundle(
141
62
  )._bundle
142
63
 
143
64
 
144
- def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
145
- if not torch_tensor.is_contiguous():
146
- torch_tensor = torch_tensor.contiguous()
147
-
148
- try:
149
- dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
150
- tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
151
- except Exception:
152
- logging.info(
153
- "Can not use dlpack to convert torch tensors. Falling back to numpy."
154
- )
155
- nparray = torch_tensor.cpu().detach().numpy()
156
- tf_tensor = tf.convert_to_tensor(nparray)
157
-
158
- return tf_tensor
159
-
160
-
161
- def _get_states(
162
- exported_programs: list[torch.export.ExportedProgram],
163
- signatures: list[Signature],
164
- ):
165
- for exported_program, signature in zip(exported_programs, signatures):
166
- args, _ = exported_program.example_inputs
167
- # Calling this to get **all** the state including model buffers.
168
- _flat_input_args = exported_program._graph_module_flat_inputs(args, {})
169
- for tensor, input_spec in zip(
170
- _flat_input_args, exported_program.graph_signature.input_specs
171
- ):
172
- # Only interested in Tensors that are part of the state (and not user input).
173
- if (
174
- not isinstance(tensor, torch.Tensor)
175
- or input_spec.kind
176
- == torch.export.graph_signature.InputKind.USER_INPUT
177
- ):
178
- continue
179
- yield signature, tensor, input_spec
180
-
181
-
182
- def _tensor_unique_id(tensor: torch.Tensor):
183
- return (
184
- str(tensor.device),
185
- tensor.shape,
186
- tensor.stride(),
187
- tensor.untyped_storage().data_ptr(),
188
- )
189
-
190
-
191
- def _gather_state_dict(
192
- exported_programs: list[torch.export.ExportedProgram],
193
- signatures: list[Signature],
194
- ):
195
- deduped_tensor_map = {}
196
-
197
- for _, tensor, _ in _get_states(exported_programs, signatures):
198
- unique_id = _tensor_unique_id(tensor)
199
- deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
200
-
201
- state_dict = {}
202
- for signature, tensor, input_spec in _get_states(
203
- exported_programs, signatures
204
- ):
205
- unique_id = _tensor_unique_id(tensor)
206
- state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
207
- unique_id
208
- ]
209
-
210
- return state_dict
211
-
212
-
213
- def merge_stablehlo_bundles(
65
+ def merge_mlir_bundles(
214
66
  bundles: list[stablehlo.StableHLOModelBundle],
215
- signatures: list[Signature],
67
+ signatures: list[signature_module.Signature],
216
68
  exported_programs: list[torch.export.ExportedProgram],
217
69
  ) -> stablehlo.StableHLOGraphModule:
218
- state_dict = _gather_state_dict(exported_programs, signatures)
70
+ state_dict = common_utils.gather_state_dict(exported_programs, signatures)
219
71
 
220
72
  new_bundle = stablehlo.StableHLOModelBundle(
221
73
  state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
@@ -232,7 +84,7 @@ def merge_stablehlo_bundles(
232
84
  loc.name = signature.name + "_" + loc.name
233
85
  new_bundle.stablehlo_funcs.append(func)
234
86
  new_bundle.additional_constants.extend(bundle.additional_constants)
235
- return stablehlo.StableHLOGraphModule(new_bundle)
87
+ return new_bundle
236
88
 
237
89
 
238
90
  def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
@@ -264,19 +116,15 @@ def _wrap_as_tf_func(
264
116
 
265
117
 
266
118
  def _make_tf_function(
267
- shlo_graph_module: stablehlo.StableHLOGraphModule,
268
119
  bundle: stablehlo.StableHLOModelBundle = None,
269
120
  ):
270
- bundle = shlo_graph_module._bundle if bundle is None else bundle
271
- return [
272
- _wrap_as_tf_func(func, bundle)
273
- for func in shlo_graph_module._bundle.stablehlo_funcs
274
- ]
121
+ bundle = bundle if bundle is None else bundle
122
+ return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]
275
123
 
276
124
 
277
125
  def _make_tf_signature(
278
126
  meta: stablehlo.StableHLOFunctionMeta,
279
- signature: Signature,
127
+ signature: signature_module.Signature,
280
128
  ) -> list[tf.TensorSpec]:
281
129
  input_names = signature.flat_arg_names
282
130
  input_pos_to_spec = {
@@ -305,60 +153,33 @@ def _make_tf_signature(
305
153
  return ret
306
154
 
307
155
 
308
- def _apply_tfl_backdoor_flags(
309
- converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict
310
- ):
311
- def _set_converter_flag(path: list):
312
- if len(path) < 2:
313
- raise ValueError("Expecting at least two values in the path.")
314
-
315
- target_obj = converter
316
- for idx in range(len(path) - 2):
317
- target_obj = getattr(target_obj, path[idx])
318
-
319
- setattr(target_obj, path[-2], path[-1])
320
-
321
- def _iterate_dict_tree(flags_dict: dict, path: list):
322
- for key, value in flags_dict.items():
323
- path.append(key)
324
- if isinstance(value, dict):
325
- _iterate_dict_tree(value, path)
326
- else:
327
- path.append(value)
328
- _set_converter_flag(path)
329
- path.pop()
330
- path.pop()
331
-
332
- _iterate_dict_tree(tfl_converter_flags, [])
333
-
334
-
335
- def _set_tfl_converter_quant_flags(
336
- converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
337
- ):
338
- if quant_config is not None:
339
- quantizer_mode = quant_config._quantizer_mode
340
- if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
341
- converter._experimental_qdq_conversion_mode = "DYNAMIC"
342
- elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
343
- converter._experimental_qdq_conversion_mode = "STATIC"
156
+ def exported_program_to_mlir_text(
157
+ exported_program: torch.export.ExportedProgram,
158
+ ) -> str:
159
+ """Converts a ExportedProgram to a MLIR text."""
160
+ return stablehlo.exported_program_to_stablehlo(
161
+ exported_program
162
+ ).get_stablehlo_text()
344
163
 
345
164
 
346
- def convert_stablehlo_to_tflite(
347
- shlo_graph_module: stablehlo.StableHLOGraphModule,
348
- signatures: list[Signature],
165
+ def merged_bundle_to_tfl_model(
166
+ bundle: stablehlo.StableHLOModelBundle,
167
+ signatures: list[signature_module.Signature],
349
168
  *,
350
169
  quant_config: Optional[qcfg.QuantConfig] = None,
351
170
  _tfl_converter_flags: dict = {},
352
171
  ) -> None:
353
172
  """Converts a StableHLOGraphModule to a tflite model.
354
- Args:
355
- shlo_graph_module - model to export and save
356
- signatures: List of signatures from which names of the signatures is extracted.
173
+
174
+ Args: shlo_bundle - model to export and save
175
+
176
+ signatures: List of signatures from which names of the signatures is
177
+ extracted.
357
178
  quant_config: User-defined quantization method and scheme of the model.
358
- _tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
179
+ _tfl_converter_flags: A nested dictionary allowing setting flags for the
180
+ underlying tflite converter.
359
181
  """
360
182
 
361
- bundle = shlo_graph_module._bundle
362
183
  tf_module = tf.Module()
363
184
  bundle.state_dict = {
364
185
  k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
@@ -371,7 +192,7 @@ def convert_stablehlo_to_tflite(
371
192
  for func, sig in zip(bundle.stablehlo_funcs, signatures)
372
193
  )
373
194
 
374
- tf_functions = _make_tf_function(shlo_graph_module, bundle)
195
+ tf_functions = _make_tf_function(bundle)
375
196
 
376
197
  tf_module.f = []
377
198
  for tf_sig, func in zip(tf_signatures, tf_functions):
@@ -413,7 +234,7 @@ def convert_stablehlo_to_tflite(
413
234
  converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
414
235
  converter._experimental_enable_composite_direct_lowering = True
415
236
 
416
- _set_tfl_converter_quant_flags(converter, quant_config)
237
+ conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
417
238
  if (
418
239
  quant_config is not None
419
240
  and quant_config._quantizer_mode
@@ -423,7 +244,7 @@ def convert_stablehlo_to_tflite(
423
244
  quant_config.generative_recipe
424
245
  )
425
246
 
426
- _apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
247
+ conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
427
248
 
428
249
  tflite_model = converter.convert()
429
250
 
ai_edge_torch/model.py CHANGED
@@ -15,17 +15,18 @@
15
15
 
16
16
  """Represents an ai_edge_torch model.
17
17
 
18
- PyTorch models can be converted to this representation through `ai_edge_torch.convert`.
18
+ PyTorch models can be converted to this representation through
19
+ `ai_edge_torch.convert`.
19
20
  """
20
21
  from __future__ import annotations
21
22
 
22
23
  import abc
23
24
 
24
- from ai_edge_torch.convert import conversion_utils as cutils
25
- import numpy as np
26
25
  import numpy.typing as npt
27
26
  import tensorflow as tf
28
27
 
28
+ DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
29
+
29
30
 
30
31
  class Model(abc.ABC):
31
32
  """Represents and edge model."""
@@ -34,7 +35,7 @@ class Model(abc.ABC):
34
35
  def __call__(
35
36
  self,
36
37
  *args: npt.ArrayLike,
37
- signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
38
+ signature_name: str = DEFAULT_SIGNATURE_NAME,
38
39
  **kwargs,
39
40
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
40
41
  raise NotImplementedError()
@@ -66,18 +67,22 @@ class TfLiteModel(Model):
66
67
  def __call__(
67
68
  self,
68
69
  *args: npt.ArrayLike,
69
- signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
70
+ signature_name: str = DEFAULT_SIGNATURE_NAME,
70
71
  **kwargs,
71
72
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
72
73
  """Runs inference on the edge model using the provided arguments.
73
74
 
74
75
  Args:
75
76
  *args: The arguments to be passed to the model for inference.
76
- **kwargs: The arguments with specific names to be passed to the model for inference.
77
- signature_name: The name of the signature to be used for inference.
78
- The default signature is used if not provided.
77
+ **kwargs: The arguments with specific names to be passed to the model for
78
+ inference.
79
+ signature_name: The name of the signature to be used for inference. The
80
+ default signature is used if not provided.
79
81
  """
80
- interpreter = tf.lite.Interpreter(model_content=self._tflite_model)
82
+ interpreter = tf.lite.Interpreter(
83
+ model_content=self._tflite_model,
84
+ experimental_default_delegate_latest_features=True,
85
+ )
81
86
  interpreter.allocate_tensors()
82
87
 
83
88
  signature_list = interpreter.get_signature_list()