ai-edge-torch-nightly 0.2.0.dev20240710__py3-none-any.whl → 0.2.0.dev20240713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/convert/conversion.py +2 -4
- ai_edge_torch/convert/conversion_utils.py +61 -3
- ai_edge_torch/convert/converter.py +47 -16
- ai_edge_torch/convert/test/test_convert.py +39 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -10
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +56 -30
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +72 -69
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +80 -72
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +1 -1
- ai_edge_torch/generative/examples/t5/t5_attention.py +6 -1
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +1 -1
- ai_edge_torch/generative/layers/model_config.py +4 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +1 -1
- ai_edge_torch/generative/layers/unet/model_config.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +9 -6
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +74 -10
- ai_edge_torch/model.py +11 -3
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -13
- {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/RECORD +23 -23
- {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/conversion.py

@@ -88,16 +88,14 @@ def convert_signatures(
   _warn_training_modules(signatures)
 
   exported_programs: torch.export.ExportedProgram = [
-      torch.export.export(
-          sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes
-      )
+      torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
       for sig in signatures
   ]
 
   # Apply default fx passes
   exported_programs = list(map(_run_convert_passes, exported_programs))
   shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
-      cutils.exported_program_to_stablehlo_bundle(exported, sig.sample_args)
+      cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
       for exported, sig in zip(exported_programs, signatures)
   ]
 
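Note on the hunk above: `torch.export.export` traces a module with positional tensors, so the new `Signature.flat_args` (defined in conversion_utils.py below) flattens sample_args plus sample_kwargs into a single positional tuple before export. A minimal standalone sketch of that call shape, assuming torch >= 2.1; `AddModule` is an illustrative module, not part of the package:

import torch

class AddModule(torch.nn.Module):

  def forward(self, x, y):
    return x + y

# (args + kwargs) flattened into one positional tuple, as sig.flat_args does.
flat_args = (torch.randn(2, 2), torch.randn(2, 2))
exported = torch.export.export(AddModule().eval(), flat_args)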
ai_edge_torch/convert/conversion_utils.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import collections
 import copy
 from dataclasses import dataclass
 import gc
@@ -22,6 +23,7 @@ import tempfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
+import torch.utils._pytree as pytree
 from torch_xla import stablehlo
 
 from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
@@ -47,8 +49,59 @@ class Signature:
   name: str
   module: torch.nn.Module
   sample_args: tuple[torch.Tensor]
+  sample_kwargs: dict[str, torch.Tensor]
   dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
 
+  @property
+  def _normalized_sample_args_kwargs(self):
+    args, kwargs = self.sample_args, self.sample_kwargs
+    if args is not None:
+      if not isinstance(args, tuple):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_args must be a tuple of torch tensors.")
+    if kwargs is not None:
+      if not isinstance(kwargs, dict) or not all(
+          isinstance(key, str) for key in kwargs.keys()
+      ):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_kwargs must be a dict of string to tensor.")
+
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else {}
+    return args, kwargs
+
+  @property
+  def flat_arg_names(self) -> list[str]:
+    spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
+    args_spec, kwargs_spec = spec.children_specs
+
+    names = []
+    for i in range(args_spec.num_leaves):
+      names.append(f"args_{i}")
+
+    dict_context = (
+        kwargs_spec.context
+        if kwargs_spec.type is not collections.defaultdict
+        # ignore mismatch of `default_factory` for defaultdict
+        else kwargs_spec.context[1]
+    )
+
+    for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
+      if value_spec.num_leaves == 1:
+        names.append(name)
+      else:
+        # value_spec.num_leaves may be greater than 1 when the value is a (nested)
+        # tuple of tensors. We haven't decided how we should support flattenable
+        # tensor containers as inputs.
+        # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
+        for i in range(value_spec.num_leaves):
+          names.append(f"{name}_{i}")
+    return names
+
+  @property
+  def flat_args(self) -> tuple[torch.Tensor]:
+    return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
+
 
 def exported_program_to_stablehlo_bundle(
     exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
@@ -189,7 +242,9 @@ def _make_tf_function(
 
 def _make_tf_signature(
     meta: stablehlo.StableHLOFunctionMeta,
+    signature: Signature,
 ) -> list[tf.TensorSpec]:
+  input_names = signature.flat_arg_names
   input_pos_to_spec = {
       loc.position: spec
       for loc, spec in itertools.chain(
@@ -197,9 +252,11 @@ def _make_tf_signature(
       )
       if loc.type_ == stablehlo.VariableType.INPUT_ARG
   }
+  assert len(input_pos_to_spec) == len(input_names)
+
   primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
   ret: list[tf.TensorSpec] = []
-  for i in range(len(input_pos_to_spec)):
+  for i, name in enumerate(input_names):
     spec = input_pos_to_spec[i]
     shape = _get_shape_with_dynamic(spec)
     ret.append(
@@ -208,7 +265,7 @@ def _make_tf_signature(
             dtype=primitive_type_to_tf_type[spec.dtype]
             if spec.dtype in primitive_type_to_tf_type
             else spec.dtype,
-            name=f"args_{i}",
+            name=name,
         )
     )
   return ret
@@ -276,7 +333,8 @@ def convert_stablehlo_to_tflite(
       tf.Variable(v, trainable=False) for v in bundle.additional_constants
   ]
   tf_signatures: list[list[tf.TensorSpec]] = list(
-      _make_tf_signature(func.meta) for func in bundle.stablehlo_funcs
+      _make_tf_signature(func.meta, sig)
+      for func, sig in zip(bundle.stablehlo_funcs, signatures)
   )
 
   tf_functions = _make_tf_function(shlo_graph_module, bundle)
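The `flat_arg_names` property above derives TFLite input names from the pytree structure of `(sample_args, sample_kwargs)`: positional inputs become `args_{i}`, keyword inputs keep their key. A standalone sketch of that naming scheme, using the same `torch.utils._pytree` calls as the hunk; the tensors and key names here are made up:

import torch
import torch.utils._pytree as pytree

args = (torch.randn(1),)  # -> "args_0"
kwargs = {"mask": torch.randn(1), "pair": (torch.randn(1), torch.randn(1))}

leaves, spec = pytree.tree_flatten((args, kwargs))
args_spec, kwargs_spec = spec.children_specs

names = [f"args_{i}" for i in range(args_spec.num_leaves)]
for key, value_spec in zip(kwargs_spec.context, kwargs_spec.children_specs):
  if value_spec.num_leaves == 1:
    names.append(key)  # "mask"
  else:
    # flattenable containers expand per leaf: "pair_0", "pair_1"
    names.extend(f"{key}_{i}" for i in range(value_spec.num_leaves))

assert len(leaves) == len(names)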
ai_edge_torch/convert/converter.py

@@ -34,17 +34,23 @@ class Converter:
       self,
       name: str,
       module: torch.nn.Module,
-      sample_args,
+      sample_args=None,
+      sample_kwargs=None,
+      *,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
   ) -> Converter:
     """Alias to `add_signature`"""
-    return self.add_signature(name, module, sample_args, dynamic_shapes=dynamic_shapes)
+    return self.add_signature(
+        name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+    )
 
   def add_signature(
       self,
       name: str,
       module: torch.nn.Module,
-      sample_args,
+      sample_args=None,
+      sample_kwargs=None,
+      *,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
   ) -> Converter:
     """Allows adding a new named torch model along with sample args to the conversion.
@@ -52,7 +58,8 @@ class Converter:
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of
+      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
 
@@ -63,13 +70,21 @@ class Converter:
     if name in [sig.name for sig in self._signatures]:
       raise ValueError(f"A signature with the provided name ({name}) is already added.")
 
-    self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes))
+    if sample_args is None and sample_kwargs is None:
+      raise ValueError("sample_args or sample_kwargs must be provided.")
+
+    self._signatures.append(
+        cutils.Signature(
+            name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+        )
+    )
     return self
 
   def convert(
       self,
       module: torch.nn.Module = None,
-      sample_args,
+      sample_args=None,
+      sample_kwargs=None,
       *,
       quant_config: Optional[qcfg.QuantConfig] = None,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@@ -88,7 +103,8 @@ class Converter:
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of
+      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
       quant_config: User-defined quantization method and scheme of the model.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
@@ -100,12 +116,20 @@ class Converter:
       ValueError: If the arguments are not provided as expected. See the example in this functions's comment.
     """
     if module is not None:
-      if sample_args is not None:  # both module and args provided
+      if (
+          sample_args is not None or sample_kwargs is not None
+      ):  # both module and args provided
         self.add_signature(
-            cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes=dynamic_shapes
+            cutils.DEFAULT_SIGNATURE_NAME,
+            module,
+            sample_args,
+            sample_kwargs,
+            dynamic_shapes=dynamic_shapes,
+        )
+      else:  # module is provided but not args
+        raise ValueError(
+            "sample_args or sample_kwargs must be provided if a module is specified."
         )
-      else:  # module is provided but not sample_args
-        raise ValueError("sample_args needs to be provided if a module is specified.")
 
     return conversion.convert_signatures(
         self._signatures,
@@ -117,7 +141,8 @@ class Converter:
 def signature(
     name: str,
     module: torch.nn.Module,
-    sample_args,
+    sample_args=None,
+    sample_kwargs=None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
 ) -> Converter:
   """Initiates a Converter object with the provided signature.
@@ -125,7 +150,8 @@ def signature(
   Args:
     name: The name of the signature included in the converted edge model.
    module: The torch module to be converted.
-    sample_args: Tuple of
+    sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
    dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
      See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
 
@@ -134,12 +160,15 @@ def signature(
     edge_model = converter.convert()
 
   """
-  return Converter().signature(name, module, sample_args, dynamic_shapes=dynamic_shapes)
+  return Converter().signature(
+      name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+  )
 
 
 def convert(
     module: torch.nn.Module = None,
-    sample_args,
+    sample_args=None,
+    sample_kwargs=None,
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@@ -149,7 +178,8 @@ def convert(
 
   Args:
     module: The torch module to be converted.
-    sample_args: Tuple of
+    sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
    quant_config: User-defined quantization method and scheme of the model.
    dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
      See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
@@ -165,6 +195,7 @@ def convert(
   return Converter().convert(
       module,
       sample_args,
+      sample_kwargs,
       quant_config=quant_config,
       dynamic_shapes=dynamic_shapes,
       _ai_edge_converter_flags=_ai_edge_converter_flags,
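Taken together, the converter.py changes let callers pass keyword inputs through the public API. A usage sketch mirroring the new tests below; `SampleModel` is illustrative:

import torch
import ai_edge_torch

class SampleModel(torch.nn.Module):

  def forward(self, x, y):
    return x + y

model = SampleModel().eval()

# Keyword-only inputs.
edge_model = ai_edge_torch.convert(
    model, sample_kwargs={"x": torch.randn(10, 10), "y": torch.randn(10, 10)}
)

# Mixed positional and keyword inputs.
edge_model = ai_edge_torch.convert(
    model, (torch.randn(10, 10),), {"y": torch.randn(10, 10)}
)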
ai_edge_torch/convert/test/test_convert.py

@@ -267,6 +267,45 @@ class TestConvert(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, model, validate_input)
     )
 
+  def test_convert_model_with_kwargs(self):
+    """
+    Test converting a simple model with sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))
+
+    model = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
+    )
+
+  def test_convert_model_with_args_kwargs(self):
+    """
+    Test converting a simple model with both sample_args and sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    args_gen = lambda: (torch.randn(10, 10),)
+    kwargs_gen = lambda: dict(y=torch.randn(10, 10))
+
+    model = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
+    )
+
 
 if __name__ == "__main__":
   unittest.main()
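The tests pass generator callables so that tracing and verification can draw identical fresh inputs. The following is not the library's implementation, only an illustrative sketch of the comparison `compare_tflite_torch` is being asked to perform:

import numpy as np
import torch

def compare_outputs_sketch(edge_model, torch_model, args_gen, kwargs_gen, atol=1e-5):
  # Draw one set of inputs and feed both models identically.
  args, kwargs = args_gen(), kwargs_gen()
  torch_out = torch_model(*args, **kwargs).detach().numpy()
  edge_out = edge_model(*args, **kwargs)  # assumes the edge model accepts kwargs too
  return np.allclose(torch_out, edge_out, atol=atol)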
ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -23,16 +23,17 @@ import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="layers.{}.
-    ff_down_proj="layers.{}.
+    ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
+    attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
+    attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
+    attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
+    pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
+    pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
+    embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+    embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+    final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
     lm_head=None,
 )
 
@@ -84,6 +85,7 @@ def get_model_config() -> cfg.ModelConfig:
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
       output_proj_use_bias=True,
       enable_kv_cache=False,
   )
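The new `TENSOR_NAMES` entries address the original Stable Diffusion checkpoint layout directly, so `loader.load(..., strict=False)` can map weights without a separate renaming pass; the `{}` placeholder is presumably formatted with the layer index by the loader. An illustrative one-liner:

template = "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1"
print(template.format(0))
# -> cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1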
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -13,8 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
+import argparse
 import os
 from pathlib import Path
+from typing import Optional
 
 import torch
 
@@ -24,14 +26,36 @@ import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
-import ai_edge_torch.generative.utilities.loader as loading_utils
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
+arg_parser = argparse.ArgumentParser()
+arg_parser.add_argument(
+    '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+)
+arg_parser.add_argument(
+    '--diffusion_ckpt',
+    type=str,
+    help='Path to source diffusion model checkpoint',
+    required=True,
+)
+arg_parser.add_argument(
+    '--decoder_ckpt',
+    type=str,
+    help='Path to source image decoder model checkpoint',
+    required=True,
+)
+arg_parser.add_argument(
+    '--output_dir',
+    type=str,
+    help='Path to the converted TF Lite directory.',
+    required=True,
+)
+
 
 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
+    output_dir: str,
     clip_ckpt_path: str,
-    encoder_ckpt_path: str,
     diffusion_ckpt_path: str,
     decoder_ckpt_path: str,
     image_height: int = 512,
@@ -39,23 +63,28 @@ def convert_stable_diffusion_to_tflite(
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
-  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader = stable_diffusion_loader.ClipModelLoader(
+      clip_ckpt_path,
+      clip.TENSOR_NAMES,
+  )
   loader.load(clip_model, strict=False)
 
-  encoder = Encoder()
-  encoder.load_state_dict(torch.load(encoder_ckpt_path))
-
   diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
   diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
-      diffusion_ckpt_path, diffusion.
+      diffusion_ckpt_path, diffusion.TENSOR_NAMES
   )
-  diffusion_loader.load(diffusion_model)
+  diffusion_loader.load(diffusion_model, strict=False)
 
   decoder_model = decoder.Decoder(decoder.get_model_config())
   decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
-      decoder_ckpt_path, decoder.
+      decoder_ckpt_path, decoder.TENSOR_NAMES
   )
-  decoder_loader.load(decoder_model)
+  decoder_loader.load(decoder_model, strict=False)
+
+  # TODO(yichunk): enable image encoder conversion
+  # if encoder_ckpt_path is not None:
+  #   encoder = Encoder()
+  #   encoder.load_state_dict(torch.load(encoder_ckpt_path))
 
   # Tensors used to trace the model graph during conversion.
   n_tokens = 77
@@ -67,50 +96,47 @@ def convert_stable_diffusion_to_tflite(
       (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
   )
 
-  input_latents = encoder(input_image, noise)
+  input_latents = torch.zeros_like(noise)
   context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)
 
+  if not os.path.exists(output_dir):
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+  # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
   ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
-      '/
+      f'{output_dir}/clip.tflite'
   )
 
-  # TODO(yichunk):
+  # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
-      '/
-  )
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  #     f'{output_dir}/encoder.tflite'
+  # )
 
   # Diffusion
   ai_edge_torch.signature(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export('/
+  ).convert().export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
   ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
-      '/
+      f'{output_dir}/decoder.tflite'
   )
 
 
 if __name__ == '__main__':
+  args = arg_parser.parse_args()
   convert_stable_diffusion_to_tflite(
-      clip_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'
-      ),
-      encoder_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/encoder.pt'
-      ),
-      diffusion_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/diffusion.pt'
-      ),
-      decoder_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/decoder.pt'
-      ),
+      output_dir=args.output_dir,
+      clip_ckpt_path=args.clip_ckpt,
+      diffusion_ckpt_path=args.diffusion_ckpt,
+      decoder_ckpt_path=args.decoder_ckpt,
       image_height=512,
       image_width=512,
   )