ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240802__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (89)
  1. ai_edge_torch/convert/conversion.py +12 -8
  2. ai_edge_torch/convert/conversion_utils.py +38 -20
  3. ai_edge_torch/convert/converter.py +11 -5
  4. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  5. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  6. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +45 -36
  7. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  8. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  9. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  15. ai_edge_torch/convert/test/test_convert.py +39 -16
  16. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  17. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  18. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  19. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  20. ai_edge_torch/debug/culprit.py +41 -16
  21. ai_edge_torch/debug/test/test_culprit.py +4 -3
  22. ai_edge_torch/debug/test/test_search_model.py +4 -3
  23. ai_edge_torch/debug/utils.py +3 -1
  24. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  25. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  26. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  27. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  30. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  31. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  32. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  33. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  34. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  35. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  36. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
  37. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
  38. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
  39. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  40. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
  41. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  45. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  46. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  47. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  52. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  53. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  54. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  55. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  56. ai_edge_torch/generative/layers/attention.py +19 -11
  57. ai_edge_torch/generative/layers/builder.py +3 -4
  58. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  59. ai_edge_torch/generative/layers/model_config.py +6 -2
  60. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  61. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  62. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  63. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  64. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  65. ai_edge_torch/generative/quantize/example.py +2 -3
  66. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  67. ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
  68. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  69. ai_edge_torch/generative/test/loader_test.py +5 -4
  70. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  71. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  72. ai_edge_torch/generative/test/test_quantize.py +45 -47
  73. ai_edge_torch/generative/utilities/loader.py +55 -28
  74. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  75. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  76. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  77. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  78. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  79. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  80. ai_edge_torch/model.py +8 -5
  81. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  83. ai_edge_torch/quantize/quant_config.py +6 -2
  84. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  85. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/RECORD +89 -89
  87. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -18,19 +18,22 @@ import os
  from pathlib import Path
  from typing import Optional

- import torch
-
  import ai_edge_torch
  import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
  import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
  import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
  from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
+ from ai_edge_torch.generative.quantize import quant_recipes
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+ import torch

  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
- '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+ '--clip_ckpt',
+ type=str,
+ help='Path to source CLIP model checkpoint',
+ required=True,
  )
  arg_parser.add_argument(
  '--diffusion_ckpt',
@@ -60,6 +63,7 @@ def convert_stable_diffusion_to_tflite(
  decoder_ckpt_path: str,
  image_height: int = 512,
  image_width: int = 512,
+ quantize: bool = True,
  ):

  clip_model = clip.CLIP(clip.get_model_config())
@@ -91,9 +95,13 @@ def convert_stable_diffusion_to_tflite(
  timestamp = 0
  len_prompt = 1
  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
- input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+ input_image = torch.full(
+ (1, 3, image_height, image_width), 0, dtype=torch.float32
+ )
  noise = torch.full(
- (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+ (len_prompt, 4, image_height // 8, image_width // 8),
+ 0,
+ dtype=torch.float32,
  )

  input_latents = torch.zeros_like(noise)
@@ -105,15 +113,19 @@ def convert_stable_diffusion_to_tflite(
  if not os.path.exists(output_dir):
  Path(output_dir).mkdir(parents=True, exist_ok=True)

+ quant_config = (
+ quant_recipes.full_int8_weight_only_recipe() if quantize else None
+ )
+
  # TODO(yichunk): convert to multi signature tflite model.
  # CLIP text encoder
- ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
- f'{output_dir}/clip.tflite'
- )
+ ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
+ quant_config=quant_config
+ ).export(f'{output_dir}/clip.tflite')

  # TODO(yichunk): enable image encoder conversion
  # Image encoder
- # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+ # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
  # f'{output_dir}/encoder.tflite'
  # )

@@ -122,12 +134,12 @@ def convert_stable_diffusion_to_tflite(
  'diffusion',
  diffusion_model,
  (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
- ).convert().export(f'{output_dir}/diffusion.tflite')
+ ).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')

  # Image decoder
- ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
- f'{output_dir}/decoder.tflite'
- )
+ ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
+ quant_config=quant_config
+ ).export(f'{output_dir}/decoder.tflite')


  if __name__ == '__main__':
@@ -139,4 +151,5 @@ if __name__ == '__main__':
  decoder_ckpt_path=args.decoder_ckpt,
  image_height=512,
  image_width=512,
+ quantize=True,
  )
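
Note: a minimal sketch of the quantized-conversion pattern introduced in this file, assembled from the hunks above. `decoder_model` and `input_latents` are placeholders for objects built earlier in the script; the output path is illustrative.

import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

# Weight-only int8 recipe added in this release; pass None to keep float conversion.
quant_config = quant_recipes.full_int8_weight_only_recipe()

# Each component is exported the same way; shown here for the image decoder.
# decoder_model and input_latents are placeholders (see the full script).
ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
    quant_config=quant_config
).export('/tmp/stable_diffusion/decoder.tflite')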

ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -13,14 +13,13 @@
  # limitations under the License.
  # ==============================================================================

- import torch
- from torch import nn
-
  import ai_edge_torch.generative.layers.builder as layers_builder
  import ai_edge_torch.generative.layers.model_config as layers_cfg
  import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+ import torch
+ from torch import nn

  TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
  post_quant_conv="first_stage_model.post_quant_conv",
@@ -104,7 +103,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
  norm_2="first_stage_model.decoder.up.1.block.0.norm2",
  conv_1="first_stage_model.decoder.up.1.block.0.conv1",
  conv_2="first_stage_model.decoder.up.1.block.0.conv2",
- residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
+ residual_layer=(
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut"
+ ),
  ),
  stable_diffusion_loader.ResidualBlockTensorNames(
  norm_1="first_stage_model.decoder.up.1.block.1.norm1",
@@ -128,7 +129,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
  norm_2="first_stage_model.decoder.up.0.block.0.norm2",
  conv_1="first_stage_model.decoder.up.0.block.0.conv1",
  conv_2="first_stage_model.decoder.up.0.block.0.conv2",
- residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
+ residual_layer=(
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut"
+ ),
  ),
  stable_diffusion_loader.ResidualBlockTensorNames(
  norm_1="first_stage_model.decoder.up.0.block.1.norm1",
@@ -293,12 +296,15 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  qkv_fused_interleaved=False,
  rotary_percentage=0.0,
  ),
+ enable_hlfb=False,
  )

  mid_block_config = unet_cfg.MidBlock2DConfig(
  in_channels=block_out_channels[-1],
  normalization_config=norm_config,
- activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+ activation_config=layers_cfg.ActivationConfig(
+ layers_cfg.ActivationType.SILU
+ ),
  num_layers=1,
  attention_block_config=att_config,
  )
@@ -307,7 +313,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  in_channels=in_channels,
  latent_channels=latent_channels,
  out_channels=out_channels,
- activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+ activation_config=layers_cfg.ActivationConfig(
+ layers_cfg.ActivationType.SILU
+ ),
  block_out_channels=block_out_channels,
  scaling_factor=scaling_factor,
  layers_per_block=layers_per_block,

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -13,14 +13,13 @@
  # limitations under the License.
  # ==============================================================================

- import torch
- from torch import nn
-
  import ai_edge_torch.generative.layers.builder as layers_builder
  import ai_edge_torch.generative.layers.model_config as layers_cfg
  import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+ import torch
+ from torch import nn

  _down_encoder_blocks_tensor_names = [
  stable_diffusion_loader.DownEncoderBlockTensorNames(
@@ -39,9 +38,15 @@ _down_encoder_blocks_tensor_names = [
  ],
  transformer_block_tensor_names=[
  stable_diffusion_loader.TransformerBlockTensorNames(
- pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
- conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
- conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
+ pre_conv_norm=(
+ f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm"
+ ),
+ conv_in=(
+ f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in"
+ ),
+ conv_out=(
+ f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out"
+ ),
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
  norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
  q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
@@ -80,7 +85,9 @@ _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
  conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
  norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
  conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
- time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
+ time_embedding=(
+ f"model.diffusion_model.middle_block.{i}.emb_layers.1"
+ ),
  )
  for i in [0, 2]
  ],
@@ -117,8 +124,12 @@ _up_decoder_blocks_tensor_names = [
  stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
  residual_block_tensor_names=[
  stable_diffusion_loader.ResidualBlockTensorNames(
- norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
- conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
+ norm_1=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0"
+ ),
+ conv_1=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2"
+ ),
  norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
  conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
  time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
@@ -128,9 +139,15 @@ _up_decoder_blocks_tensor_names = [
  ],
  transformer_block_tensor_names=[
  stable_diffusion_loader.TransformerBlockTensorNames(
- pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
- conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
- conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
+ pre_conv_norm=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.1.norm"
+ ),
+ conv_in=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in"
+ ),
+ conv_out=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out"
+ ),
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
  norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
  q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
@@ -157,7 +174,9 @@ _up_decoder_blocks_tensor_names = [
  else None,
  upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
  if 0 < i < 3
- else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
+ else (
+ f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None
+ ),
  )
  for i in range(4)
  ]
@@ -294,6 +313,7 @@ class Diffusion(nn.Module):
  attention_batch_size=config.transformer_batch_size,
  normalization_config=config.transformer_norm_config,
  attention_config=attention_config,
+ enable_hlfb=False,
  ),
  cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
  query_dim=output_channel,
@@ -301,6 +321,7 @@ class Diffusion(nn.Module):
  attention_batch_size=config.transformer_batch_size,
  normalization_config=config.transformer_norm_config,
  attention_config=attention_config,
+ enable_hlfb=False,
  ),
  pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
  feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -354,6 +375,7 @@ class Diffusion(nn.Module):
  attention_batch_size=config.transformer_batch_size,
  normalization_config=config.transformer_norm_config,
  attention_config=attention_config,
+ enable_hlfb=False,
  ),
  cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
  query_dim=mid_block_channels,
@@ -361,6 +383,7 @@ class Diffusion(nn.Module):
  attention_batch_size=config.transformer_batch_size,
  normalization_config=config.transformer_norm_config,
  attention_config=attention_config,
+ enable_hlfb=False,
  ),
  pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
  feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -415,6 +438,7 @@ class Diffusion(nn.Module):
  attention_batch_size=config.transformer_batch_size,
  normalization_config=config.transformer_norm_config,
  attention_config=attention_config,
+ enable_hlfb=False,
  ),
  cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
  query_dim=output_channel,
@@ -422,6 +446,7 @@ class Diffusion(nn.Module):
  attention_batch_size=config.transformer_batch_size,
  normalization_config=config.transformer_norm_config,
  attention_config=attention_config,
+ enable_hlfb=False,
  ),
  pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
  feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -469,7 +494,10 @@ class Diffusion(nn.Module):
  layers_cfg.ActivationConfig(config.final_activation_type)
  )
  self.conv_out = nn.Conv2d(
- reversed_block_out_channels[-1], config.out_channels, kernel_size=3, padding=1
+ reversed_block_out_channels[-1],
+ config.out_channels,
+ kernel_size=3,
+ padding=1,
  )

  @torch.inference_mode
@@ -490,12 +518,15 @@ class Diffusion(nn.Module):
  x = self.conv_in(latents)
  skip_connection_tensors = [x]
  for encoder in self.down_encoders:
- x, hidden_states = encoder(x, time_emb, context, output_hidden_states=True)
+ x, hidden_states = encoder(
+ x, time_emb, context, output_hidden_states=True
+ )
  skip_connection_tensors.extend(hidden_states)
  x = self.mid_block(x, time_emb, context)
  for decoder in self.up_decoders:
  encoder_tensors = [
- skip_connection_tensors.pop() for i in range(self.config.layers_per_block + 1)
+ skip_connection_tensors.pop()
+ for i in range(self.config.layers_per_block + 1)
  ]
  x = decoder(x, encoder_tensors, time_emb, context)
  x = self.final_norm(x)

ai_edge_torch/generative/examples/stable_diffusion/encoder.py

@@ -13,12 +13,11 @@
  # limitations under the License.
  # ==============================================================================

+ from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA
  import torch
  from torch import nn
  from torch.nn import functional as F

- from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA
-

  class AttentionBlock(nn.Module):

@@ -50,7 +49,9 @@ class ResidualBlock(nn.Module):
  self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

  self.groupnorm_2 = nn.GroupNorm(32, out_channels)
- self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+ self.conv_2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, padding=1
+ )

  if in_channels == out_channels:
  self.residual_layer = nn.Identity()

ai_edge_torch/generative/examples/stable_diffusion/pipeline.py

@@ -18,30 +18,41 @@ import os
  from pathlib import Path
  from typing import Dict, Optional

- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
  import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
  from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer # NOQA
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
  from ai_edge_torch.model import TfLiteModel
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm

  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
  '--tokenizer_vocab_dir',
  type=str,
- help='Directory to the tokenizer vocabulary files, which include `merges.txt` and `vocab.json`',
+ help=(
+ 'Directory to the tokenizer vocabulary files, which include'
+ ' `merges.txt` and `vocab.json`'
+ ),
  required=True,
  )
  arg_parser.add_argument(
- '--clip_ckpt', type=str, help='Path to CLIP TFLite tflite file', required=True
+ '--clip_ckpt',
+ type=str,
+ help='Path to CLIP TFLite tflite file',
+ required=True,
  )
  arg_parser.add_argument(
- '--diffusion_ckpt', type=str, help='Path to diffusion tflite file', required=True
+ '--diffusion_ckpt',
+ type=str,
+ help='Path to diffusion tflite file',
+ required=True,
  )
  arg_parser.add_argument(
- '--decoder_ckpt', type=str, help='Path to decoder tflite file', required=True
+ '--decoder_ckpt',
+ type=str,
+ help='Path to decoder tflite file',
+ required=True,
  )
  arg_parser.add_argument(
  '--output_path',
@@ -56,14 +67,29 @@ arg_parser.add_argument(
  help='The prompt to guide the image generation.',
  )
  arg_parser.add_argument(
- '--n_inference_steps', default=20, type=int, help='The number of denoising steps.'
+ '--n_inference_steps',
+ default=20,
+ type=int,
+ help='The number of denoising steps.',
  )
  arg_parser.add_argument(
  '--sampler',
  default='k_euler',
  type=str,
  choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
- help='A sampler to be used to denoise the encoded image latents. Can be one of `k_lms, `k_euler`, or `k_euler_ancestral`.',
+ help=(
+ 'A sampler to be used to denoise the encoded image latents. Can be one'
+ ' of `k_lms, `k_euler`, or `k_euler_ancestral`.'
+ ),
+ )
+ arg_parser.add_argument(
+ '--seed',
+ default=None,
+ type=int,
+ help=(
+ 'A seed to make generation deterministic. A random number is used if'
+ ' unspecified.'
+ ),
  )


@@ -148,7 +174,9 @@ def run_tflite_pipeline(
  elif sampler == 'k_euler':
  sampler = samplers.KEulerSampler(n_inference_steps=n_inference_steps)
  elif sampler == 'k_euler_ancestral':
- sampler = samplers.KEulerAncestralSampler(n_inference_steps=n_inference_steps)
+ sampler = samplers.KEulerAncestralSampler(
+ n_inference_steps=n_inference_steps
+ )
  else:
  raise ValueError(
  'Unknown sampler value %s. '
@@ -167,7 +195,8 @@ def run_tflite_pipeline(
  if input_image:
  if not hasattr(model, 'encoder'):
  raise AttributeError(
- 'Stable Diffusion must be initialized with encoder to accept input_image.'
+ 'Stable Diffusion must be initialized with encoder to accept'
+ ' input_image.'
  )
  input_image = input_image.resize((width, height))
  input_image_np = np.array(input_image).astype(np.float32)
@@ -219,4 +248,5 @@ if __name__ == '__main__':
  output_path=args.output_path,
  sampler=args.sampler,
  n_inference_steps=args.n_inference_steps,
+ seed=args.seed,
  )
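
Note: with the new `--seed` flag, a pipeline run can be made reproducible. A minimal sketch; the keyword names mirror the CLI flags above, and the `model` argument plus the exact positional order are assumptions rather than details visible in this diff.

from ai_edge_torch.generative.examples.stable_diffusion import pipeline

pipeline.run_tflite_pipeline(
    model,  # hypothetical: the StableDiffusion wrapper built from the exported tflite files
    prompt='a photograph of an astronaut riding a horse',
    output_path='/tmp/sd_result.jpg',
    sampler='k_euler',
    n_inference_steps=20,
    seed=42,  # new: makes generation deterministic; a random seed is used if None
)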

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py

@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

- import numpy as np
-
  from ai_edge_torch.generative.examples.stable_diffusion import util
  from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface # NOQA
+ import numpy as np


  class KEulerSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerSampler(SamplerInterface):

  def set_strength(self, strength=1):
  start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
- self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+ self.timesteps = np.linspace(
+ self.n_training_steps - 1, 0, self.n_inference_steps
+ )
  self.timesteps = self.timesteps[start_step:]
  self.initial_scale = self.sigmas[start_step]
  self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py

@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

- import numpy as np
-
  from ai_edge_torch.generative.examples.stable_diffusion import util
  from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface # NOQA
+ import numpy as np


  class KEulerAncestralSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerAncestralSampler(SamplerInterface):

  def set_strength(self, strength=1):
  start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
- self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+ self.timesteps = np.linspace(
+ self.n_training_steps - 1, 0, self.n_inference_steps
+ )
  self.timesteps = self.timesteps[start_step:]
  self.initial_scale = self.sigmas[start_step]
  self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py

@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

- import numpy as np
-
  from ai_edge_torch.generative.examples.stable_diffusion import util
  from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface # NOQA
+ import numpy as np


  class KLMSSampler(SamplerInterface):
@@ -48,7 +47,9 @@ class KLMSSampler(SamplerInterface):

  def set_strength(self, strength=1):
  start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
- self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+ self.timesteps = np.linspace(
+ self.n_training_steps - 1, 0, self.n_inference_steps
+ )
  self.timesteps = self.timesteps[start_step:]
  self.initial_scale = self.sigmas[start_step]
  self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py

@@ -27,7 +27,10 @@ def create_bytes_table() -> dict:
  special_count = 0
  for byte in range(256):
  category = unicodedata.category(chr(byte))
- if category[0] not in ['C', 'Z']: # ith character is NOT control char or space
+ if category[0] not in [
+ 'C',
+ 'Z',
+ ]: # ith character is NOT control char or space
  table[byte] = chr(byte)
  else: # ith character IS control char or space
  table[byte] = chr(special_count + 256)

ai_edge_torch/generative/examples/stable_diffusion/util.py

@@ -20,14 +20,20 @@ import torch


  def get_time_embedding(timestep):
- freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+ freqs = torch.pow(
+ 10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160
+ )
  x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
  return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)


- def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+ def get_alphas_cumprod(
+ beta_start=0.00085, beta_end=0.0120, n_training_steps=1000
+ ):
  betas = (
- np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32)
+ np.linspace(
+ beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32
+ )
  ** 2
  )
  alphas = 1.0 - betas
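
Note: the two helpers above are only reformatted, not changed in behavior. A small usage sketch for reference (the 320-dim shape follows from the 160 cos/sin frequencies shown in the code):

from ai_edge_torch.generative.examples.stable_diffusion import util

time_emb = util.get_time_embedding(timestep=999)  # tensor of shape (1, 320): cos/sin of 160 frequencies
alphas_cumprod = util.get_alphas_cumprod()        # noise schedule for the default 1000 training steps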

ai_edge_torch/generative/examples/t5/convert_to_tflite.py

@@ -16,12 +16,11 @@
  import os
  from pathlib import Path

- import numpy as np
- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.t5 import t5
  from ai_edge_torch.generative.quantize import quant_recipes
+ import numpy as np
+ import torch


  # TODO(haoliang): clean this up untile 2-sig model is validated e2e.
@@ -73,8 +72,12 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
  embedding_layer = torch.nn.Embedding(
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
- t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
- t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+ t5_encoder_model = t5.build_t5_encoder_model(
+ config, embedding_layer, checkpoint_path
+ )
+ t5_decoder_model = t5.build_t5_decoder_model(
+ config, embedding_layer, checkpoint_path
+ )

  # encoder
  seq_len = 512