ai-edge-torch-nightly 0.2.0.dev20240707__py3-none-any.whl → 0.2.0.dev20240713__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (23)
  1. ai_edge_torch/convert/conversion.py +2 -4
  2. ai_edge_torch/convert/conversion_utils.py +61 -3
  3. ai_edge_torch/convert/converter.py +47 -16
  4. ai_edge_torch/convert/test/test_convert.py +39 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -10
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +56 -30
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +72 -69
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +80 -72
  9. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +1 -1
  10. ai_edge_torch/generative/examples/t5/t5_attention.py +6 -1
  11. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +1 -1
  12. ai_edge_torch/generative/layers/model_config.py +4 -0
  13. ai_edge_torch/generative/layers/unet/blocks_2d.py +1 -1
  14. ai_edge_torch/generative/layers/unet/model_config.py +5 -5
  15. ai_edge_torch/generative/utilities/loader.py +9 -6
  16. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +74 -10
  17. ai_edge_torch/model.py +11 -3
  18. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -13
  19. {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/RECORD +23 -23
  21. {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/top_level.txt +0 -0
ai_edge_torch/convert/conversion.py
@@ -88,16 +88,14 @@ def convert_signatures(
   _warn_training_modules(signatures)
 
   exported_programs: torch.export.ExportedProgram = [
-      torch.export.export(
-          sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes
-      )
+      torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
       for sig in signatures
   ]
 
   # Apply default fx passes
   exported_programs = list(map(_run_convert_passes, exported_programs))
   shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
-      cutils.exported_program_to_stablehlo_bundle(exported, sig.sample_args)
+      cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
       for exported, sig in zip(exported_programs, signatures)
   ]
 
ai_edge_torch/convert/conversion_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import collections
 import copy
 from dataclasses import dataclass
 import gc
@@ -22,6 +23,7 @@ import tempfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
+import torch.utils._pytree as pytree
 from torch_xla import stablehlo
 
 from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
@@ -47,8 +49,59 @@ class Signature:
   name: str
   module: torch.nn.Module
   sample_args: tuple[torch.Tensor]
+  sample_kwargs: dict[str, torch.Tensor]
   dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
 
+  @property
+  def _normalized_sample_args_kwargs(self):
+    args, kwargs = self.sample_args, self.sample_kwargs
+    if args is not None:
+      if not isinstance(args, tuple):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_args must be a tuple of torch tensors.")
+    if kwargs is not None:
+      if not isinstance(kwargs, dict) or not all(
+          isinstance(key, str) for key in kwargs.keys()
+      ):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_kwargs must be a dict of string to tensor.")
+
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else {}
+    return args, kwargs
+
+  @property
+  def flat_arg_names(self) -> list[str]:
+    spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
+    args_spec, kwargs_spec = spec.children_specs
+
+    names = []
+    for i in range(args_spec.num_leaves):
+      names.append(f"args_{i}")
+
+    dict_context = (
+        kwargs_spec.context
+        if kwargs_spec.type is not collections.defaultdict
+        # ignore mismatch of `default_factory` for defaultdict
+        else kwargs_spec.context[1]
+    )
+
+    for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
+      if value_spec.num_leaves == 1:
+        names.append(name)
+      else:
+        # value_spec.num_leaves may be greater than 1 when the value is a (nested)
+        # tuple of tensors. We haven't decided how we should support flattenable
+        # tensor containers as inputs.
+        # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
+        for i in range(value_spec.num_leaves):
+          names.append(f"{name}_{i}")
+    return names
+
+  @property
+  def flat_args(self) -> tuple[torch.Tensor]:
+    return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
+
 
 def exported_program_to_stablehlo_bundle(
     exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
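
The new `Signature.flat_arg_names` and `flat_args` properties flatten (sample_args, sample_kwargs) with torch.utils._pytree: positional inputs keep the args_{i} naming, while keyword inputs are named after their keys. A minimal standalone sketch of that naming scheme (the sample values below are illustrative, not taken from the package):

    import torch
    import torch.utils._pytree as pytree

    # One positional tensor plus two kwargs; "cache" is a tuple of tensors,
    # which exercises the multi-leaf branch of flat_arg_names.
    sample = (
        (torch.zeros(1),),
        {"mask": torch.zeros(1), "cache": (torch.zeros(1), torch.zeros(1))},
    )
    leaves, spec = pytree.tree_flatten(sample)
    args_spec, kwargs_spec = spec.children_specs

    names = [f"args_{i}" for i in range(args_spec.num_leaves)]
    for name, value_spec in zip(kwargs_spec.context, kwargs_spec.children_specs):
      if value_spec.num_leaves == 1:
        names.append(name)
      else:
        names.extend(f"{name}_{i}" for i in range(value_spec.num_leaves))

    print(names)  # ['args_0', 'mask', 'cache_0', 'cache_1']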
@@ -189,7 +242,9 @@ def _make_tf_function(
 
 def _make_tf_signature(
     meta: stablehlo.StableHLOFunctionMeta,
+    signature: Signature,
 ) -> list[tf.TensorSpec]:
+  input_names = signature.flat_arg_names
   input_pos_to_spec = {
       loc.position: spec
       for loc, spec in itertools.chain(
@@ -197,9 +252,11 @@ def _make_tf_signature(
       )
       if loc.type_ == stablehlo.VariableType.INPUT_ARG
   }
+  assert len(input_pos_to_spec) == len(input_names)
+
   primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
   ret: list[tf.TensorSpec] = []
-  for i in range(len(input_pos_to_spec)):
+  for i, name in enumerate(input_names):
     spec = input_pos_to_spec[i]
     shape = _get_shape_with_dynamic(spec)
     ret.append(
@@ -208,7 +265,7 @@ def _make_tf_signature(
             dtype=primitive_type_to_tf_type[spec.dtype]
            if spec.dtype in primitive_type_to_tf_type
            else spec.dtype,
-            name=f"args_{i}",
+            name=name,
         )
     )
   return ret
@@ -276,7 +333,8 @@ def convert_stablehlo_to_tflite(
       tf.Variable(v, trainable=False) for v in bundle.additional_constants
   ]
   tf_signatures: list[list[tf.TensorSpec]] = list(
-      _make_tf_signature(func.meta) for func in bundle.stablehlo_funcs
+      _make_tf_signature(func.meta, sig)
+      for func, sig in zip(bundle.stablehlo_funcs, signatures)
   )
 
   tf_functions = _make_tf_function(shlo_graph_module, bundle)
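
Because `_make_tf_signature` now receives the per-signature `Signature`, the exported TFLite inputs carry the flattened argument names instead of uniform args_{i}. A hedged sketch of checking this with the stock TF Lite interpreter (the model path is a placeholder for any model exported by ai_edge_torch, and it assumes the model holds a single signature):

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="/tmp/model.tflite")
    runner = interpreter.get_signature_runner()  # valid when there is one signature
    print(list(runner.get_input_details().keys()))  # e.g. ['args_0', 'y']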
ai_edge_torch/convert/converter.py
@@ -34,17 +34,23 @@ class Converter:
       self,
       name: str,
       module: torch.nn.Module,
-      sample_args: tuple[cutils.TracingArg],
+      sample_args=None,
+      sample_kwargs=None,
+      *,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
   ) -> Converter:
     """Alias to `add_signature`"""
-    return self.add_signature(name, module, sample_args, dynamic_shapes)
+    return self.add_signature(
+        name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+    )
 
   def add_signature(
       self,
       name: str,
       module: torch.nn.Module,
-      sample_args: tuple[cutils.TracingArg],
+      sample_args=None,
+      sample_kwargs=None,
+      *,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
   ) -> Converter:
     """Allows adding a new named torch model along with sample args to the conversion.
@@ -52,7 +58,8 @@ class Converter:
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
 
@@ -63,13 +70,21 @@ class Converter:
     if name in [sig.name for sig in self._signatures]:
       raise ValueError(f"A signature with the provided name ({name}) is already added.")
 
-    self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes))
+    if sample_args is None and sample_kwargs is None:
+      raise ValueError("sample_args or sample_kwargs must be provided.")
+
+    self._signatures.append(
+        cutils.Signature(
+            name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+        )
+    )
     return self
 
   def convert(
       self,
       module: torch.nn.Module = None,
-      sample_args: tuple[cutils.TracingArg] = None,
+      sample_args=None,
+      sample_kwargs=None,
       *,
       quant_config: Optional[qcfg.QuantConfig] = None,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@@ -88,7 +103,8 @@ class Converter:
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
       quant_config: User-defined quantization method and scheme of the model.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
@@ -100,12 +116,20 @@ class Converter:
       ValueError: If the arguments are not provided as expected. See the example in this functions's comment.
     """
     if module is not None:
-      if sample_args is not None:  # both module and args provided
+      if (
+          sample_args is not None or sample_kwargs is not None
+      ):  # both module and args provided
         self.add_signature(
-            cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes
+            cutils.DEFAULT_SIGNATURE_NAME,
+            module,
+            sample_args,
+            sample_kwargs,
+            dynamic_shapes=dynamic_shapes,
+        )
+      else:  # module is provided but not args
+        raise ValueError(
+            "sample_args or sample_kwargs must be provided if a module is specified."
         )
-      else:  # module is provided but not sample_args
-        raise ValueError("sample_args needs to be provided if a module is specified.")
 
     return conversion.convert_signatures(
         self._signatures,
@@ -117,7 +141,8 @@ class Converter:
 def signature(
     name: str,
     module: torch.nn.Module,
-    sample_args: tuple[cutils.TracingArg],
+    sample_args=None,
+    sample_kwargs=None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
 ) -> Converter:
   """Initiates a Converter object with the provided signature.
@@ -125,7 +150,8 @@ def signature(
   Args:
     name: The name of the signature included in the converted edge model.
     module: The torch module to be converted.
-    sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+    sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
     dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
       See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
 
@@ -134,12 +160,15 @@ def signature(
     edge_model = converter.convert()
 
   """
-  return Converter().signature(name, module, sample_args, dynamic_shapes)
+  return Converter().signature(
+      name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+  )
 
 
 def convert(
     module: torch.nn.Module = None,
-    sample_args: tuple[cutils.TracingArg] = None,
+    sample_args=None,
+    sample_kwargs=None,
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@@ -149,7 +178,8 @@ def convert(
 
   Args:
     module: The torch module to be converted.
-    sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+    sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
     quant_config: User-defined quantization method and scheme of the model.
     dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
       See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
@@ -165,6 +195,7 @@ def convert(
   return Converter().convert(
       module,
       sample_args,
+      sample_kwargs,
      quant_config=quant_config,
      dynamic_shapes=dynamic_shapes,
      _ai_edge_converter_flags=_ai_edge_converter_flags,
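
Taken together, these converter.py changes let ai_edge_torch.convert accept keyword sample inputs alongside positional ones, as exercised by the new tests below. A small usage sketch (the module and shapes are illustrative):

    import torch
    import ai_edge_torch

    class Add(torch.nn.Module):

      def forward(self, x, y):
        return x + y

    model = Add().eval()
    edge_model = ai_edge_torch.convert(
        model,
        sample_args=(torch.randn(4, 4),),
        sample_kwargs={"y": torch.randn(4, 4)},
    )
    edge_model.export("/tmp/add.tflite")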
ai_edge_torch/convert/test/test_convert.py
@@ -267,6 +267,45 @@ class TestConvert(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, model, validate_input)
     )
 
+  def test_convert_model_with_kwargs(self):
+    """
+    Test converting a simple model with sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))
+
+    model = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
+    )
+
+  def test_convert_model_with_args_kwargs(self):
+    """
+    Test converting a simple model with both sample_args and sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    args_gen = lambda: (torch.randn(10, 10),)
+    kwargs_gen = lambda: dict(y=torch.randn(10, 10))
+
+    model = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
+    )
+
 
 if __name__ == "__main__":
   unittest.main()
ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -23,16 +23,17 @@ import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="layers.{}.linear_1",
-    ff_down_proj="layers.{}.linear_2",
-    ff_gate_proj="layers.{}.linear_1",
-    attn_fused_qkv_proj="layers.{}.attention.in_proj",
-    attn_output_proj="layers.{}.attention.out_proj",
-    pre_attn_norm="layers.{}.layernorm_1",
-    pre_ff_norm="layers.{}.layernorm_2",
-    embedding="embedding.token_embedding",
-    embedding_position="embedding.position_value",
-    final_norm="layernorm",
+    ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
+    attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
+    attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
+    attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
+    pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
+    pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
+    embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+    embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+    final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
     lm_head=None,
 )
 
@@ -84,6 +85,7 @@ def get_model_config() -> cfg.ModelConfig:
       rotary_percentage=0.0,
       qkv_use_bias=True,
       qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
       output_proj_use_bias=True,
       enable_kv_cache=False,
   )
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 # ==============================================================================
 
+import argparse
 import os
 from pathlib import Path
+from typing import Optional
 
 import torch
 
@@ -24,14 +26,36 @@ import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
-import ai_edge_torch.generative.utilities.loader as loading_utils
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
+arg_parser = argparse.ArgumentParser()
+arg_parser.add_argument(
+    '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+)
+arg_parser.add_argument(
+    '--diffusion_ckpt',
+    type=str,
+    help='Path to source diffusion model checkpoint',
+    required=True,
+)
+arg_parser.add_argument(
+    '--decoder_ckpt',
+    type=str,
+    help='Path to source image decoder model checkpoint',
+    required=True,
+)
+arg_parser.add_argument(
+    '--output_dir',
+    type=str,
+    help='Path to the converted TF Lite directory.',
+    required=True,
+)
+
 
 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
+    output_dir: str,
     clip_ckpt_path: str,
-    encoder_ckpt_path: str,
     diffusion_ckpt_path: str,
     decoder_ckpt_path: str,
     image_height: int = 512,
@@ -39,23 +63,28 @@ def convert_stable_diffusion_to_tflite(
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
-  loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+  loader = stable_diffusion_loader.ClipModelLoader(
+      clip_ckpt_path,
+      clip.TENSOR_NAMES,
+  )
   loader.load(clip_model, strict=False)
 
-  encoder = Encoder()
-  encoder.load_state_dict(torch.load(encoder_ckpt_path))
-
   diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
   diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
-      diffusion_ckpt_path, diffusion.TENSORS_NAMES
+      diffusion_ckpt_path, diffusion.TENSOR_NAMES
   )
-  diffusion_loader.load(diffusion_model)
+  diffusion_loader.load(diffusion_model, strict=False)
 
   decoder_model = decoder.Decoder(decoder.get_model_config())
   decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
-      decoder_ckpt_path, decoder.TENSORS_NAMES
+      decoder_ckpt_path, decoder.TENSOR_NAMES
   )
-  decoder_loader.load(decoder_model)
+  decoder_loader.load(decoder_model, strict=False)
+
+  # TODO(yichunk): enable image encoder conversion
+  # if encoder_ckpt_path is not None:
+  #   encoder = Encoder()
+  #   encoder.load_state_dict(torch.load(encoder_ckpt_path))
 
   # Tensors used to trace the model graph during conversion.
   n_tokens = 77
@@ -67,50 +96,47 @@ def convert_stable_diffusion_to_tflite(
       (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
   )
 
-  input_latents = encoder(input_image, noise)
+  input_latents = torch.zeros_like(noise)
   context_cond = clip_model(prompt_tokens)
   context_uncond = torch.zeros_like(context_cond)
   context = torch.cat([context_cond, context_uncond], axis=0)
   time_embedding = util.get_time_embedding(timestamp)
 
+  if not os.path.exists(output_dir):
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+  # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
   ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
-      '/tmp/stable_diffusion/clip.tflite'
+      f'{output_dir}/clip.tflite'
   )
 
-  # TODO(yichunk): convert to multi signature tflite model.
+  # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
-      '/tmp/stable_diffusion/encoder.tflite'
-  )
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  #     f'{output_dir}/encoder.tflite'
+  # )
 
   # Diffusion
   ai_edge_torch.signature(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export('/tmp/stable_diffusion/diffusion.tflite')
+  ).convert().export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
   ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
-      '/tmp/stable_diffusion/decoder.tflite'
+      f'{output_dir}/decoder.tflite'
   )
 
 
 if __name__ == '__main__':
+  args = arg_parser.parse_args()
   convert_stable_diffusion_to_tflite(
-      clip_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'
-      ),
-      encoder_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/encoder.pt'
-      ),
-      diffusion_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/diffusion.pt'
-      ),
-      decoder_ckpt_path=os.path.join(
-          Path.home(), 'Downloads/stable_diffusion_data/ckpt/decoder.pt'
-      ),
+      output_dir=args.output_dir,
+      clip_ckpt_path=args.clip_ckpt,
+      diffusion_ckpt_path=args.diffusion_ckpt,
+      decoder_ckpt_path=args.decoder_ckpt,
      image_height=512,
      image_width=512,
  )
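
With the hard-coded ~/Downloads checkpoint paths replaced by argparse flags, the script is now driven from the command line. A hedged sketch of the equivalent programmatic call (all paths are placeholders):

    from ai_edge_torch.generative.examples.stable_diffusion.convert_to_tflite import (
        convert_stable_diffusion_to_tflite,
    )

    # Mirrors: python convert_to_tflite.py --clip_ckpt ... --diffusion_ckpt ...
    #          --decoder_ckpt ... --output_dir ...
    convert_stable_diffusion_to_tflite(
        output_dir='/tmp/stable_diffusion',
        clip_ckpt_path='/path/to/clip.pt',
        diffusion_ckpt_path='/path/to/diffusion.pt',
        decoder_ckpt_path='/path/to/decoder.pt',
        image_height=512,
        image_width=512,
    )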