ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly has been flagged as possibly problematic.
Files changed (89)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/test/loader_test.py +5 -4
  69. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  70. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  71. ai_edge_torch/generative/test/test_quantize.py +45 -48
  72. ai_edge_torch/generative/utilities/loader.py +55 -28
  73. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  74. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  75. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  76. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  79. ai_edge_torch/model.py +8 -5
  80. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  81. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  82. ai_edge_torch/quantize/quant_config.py +6 -2
  83. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  84. ai_edge_torch/version.py +16 -0
  85. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
  87. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -18,8 +18,6 @@ import os
  from pathlib import Path
  from typing import Optional

- import torch
-
  import ai_edge_torch
  import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
  import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
@@ -28,10 +26,14 @@ from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
  from ai_edge_torch.generative.quantize import quant_recipes
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+ import torch

  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
- '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+ '--clip_ckpt',
+ type=str,
+ help='Path to source CLIP model checkpoint',
+ required=True,
  )
  arg_parser.add_argument(
  '--diffusion_ckpt',
@@ -93,9 +95,13 @@ def convert_stable_diffusion_to_tflite(
  timestamp = 0
  len_prompt = 1
  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
- input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+ input_image = torch.full(
+ (1, 3, image_height, image_width), 0, dtype=torch.float32
+ )
  noise = torch.full(
- (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+ (len_prompt, 4, image_height // 8, image_width // 8),
+ 0,
+ dtype=torch.float32,
  )

  input_latents = torch.zeros_like(noise)
@@ -107,7 +113,9 @@
  if not os.path.exists(output_dir):
  Path(output_dir).mkdir(parents=True, exist_ok=True)

- quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
+ quant_config = (
+ quant_recipes.full_int8_weight_only_recipe() if quantize else None
+ )

  # TODO(yichunk): convert to multi signature tflite model.
  # CLIP text encoder
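
For context on what this script ultimately does with quant_config, the sketch below is a minimal, hedged illustration of an ai-edge-torch convert-and-export flow. It assumes the public ai_edge_torch.convert(...).export(...) API, mirrors only the quant_config expression from the last hunk above, and uses a hypothetical TinyModel and output path in place of the real CLIP/diffusion/decoder sub-models; the actual script may drive the multi-signature converter instead.

    import ai_edge_torch
    import torch
    from ai_edge_torch.generative.quantize import quant_recipes

    class TinyModel(torch.nn.Module):
      """Hypothetical stand-in for the CLIP/diffusion/decoder sub-models."""

      def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

      def forward(self, x):
        return torch.nn.functional.relu(self.proj(x))

    quantize = False  # set True to attach the weight-only int8 recipe from the hunk above
    quant_config = (
        quant_recipes.full_int8_weight_only_recipe() if quantize else None
    )

    sample_args = (torch.zeros(1, 4, dtype=torch.float32),)
    edge_model = ai_edge_torch.convert(
        TinyModel().eval(), sample_args, quant_config=quant_config
    )
    edge_model.export('/tmp/tiny_model.tflite')  # hypothetical output path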

ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -13,14 +13,13 @@
  # limitations under the License.
  # ==============================================================================

- import torch
- from torch import nn
-
  import ai_edge_torch.generative.layers.builder as layers_builder
  import ai_edge_torch.generative.layers.model_config as layers_cfg
  import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+ import torch
+ from torch import nn

  TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
  post_quant_conv="first_stage_model.post_quant_conv",
@@ -104,7 +103,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
  norm_2="first_stage_model.decoder.up.1.block.0.norm2",
  conv_1="first_stage_model.decoder.up.1.block.0.conv1",
  conv_2="first_stage_model.decoder.up.1.block.0.conv2",
- residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
+ residual_layer=(
+ "first_stage_model.decoder.up.1.block.0.nin_shortcut"
+ ),
  ),
  stable_diffusion_loader.ResidualBlockTensorNames(
  norm_1="first_stage_model.decoder.up.1.block.1.norm1",
@@ -128,7 +129,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
  norm_2="first_stage_model.decoder.up.0.block.0.norm2",
  conv_1="first_stage_model.decoder.up.0.block.0.conv1",
  conv_2="first_stage_model.decoder.up.0.block.0.conv2",
- residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
+ residual_layer=(
+ "first_stage_model.decoder.up.0.block.0.nin_shortcut"
+ ),
  ),
  stable_diffusion_loader.ResidualBlockTensorNames(
  norm_1="first_stage_model.decoder.up.0.block.1.norm1",
@@ -299,7 +302,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  mid_block_config = unet_cfg.MidBlock2DConfig(
  in_channels=block_out_channels[-1],
  normalization_config=norm_config,
- activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+ activation_config=layers_cfg.ActivationConfig(
+ layers_cfg.ActivationType.SILU
+ ),
  num_layers=1,
  attention_block_config=att_config,
  )
@@ -308,7 +313,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  in_channels=in_channels,
  latent_channels=latent_channels,
  out_channels=out_channels,
- activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+ activation_config=layers_cfg.ActivationConfig(
+ layers_cfg.ActivationType.SILU
+ ),
  block_out_channels=block_out_channels,
  scaling_factor=scaling_factor,
  layers_per_block=layers_per_block,

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -13,14 +13,13 @@
  # limitations under the License.
  # ==============================================================================

- import torch
- from torch import nn
-
  import ai_edge_torch.generative.layers.builder as layers_builder
  import ai_edge_torch.generative.layers.model_config as layers_cfg
  import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+ import torch
+ from torch import nn

  _down_encoder_blocks_tensor_names = [
  stable_diffusion_loader.DownEncoderBlockTensorNames(
@@ -39,9 +38,15 @@ _down_encoder_blocks_tensor_names = [
  ],
  transformer_block_tensor_names=[
  stable_diffusion_loader.TransformerBlockTensorNames(
- pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
- conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
- conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
+ pre_conv_norm=(
+ f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm"
+ ),
+ conv_in=(
+ f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in"
+ ),
+ conv_out=(
+ f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out"
+ ),
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
  norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
  q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
@@ -80,7 +85,9 @@ _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
  conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
  norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
  conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
- time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
+ time_embedding=(
+ f"model.diffusion_model.middle_block.{i}.emb_layers.1"
+ ),
  )
  for i in [0, 2]
  ],
@@ -117,8 +124,12 @@ _up_decoder_blocks_tensor_names = [
  stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
  residual_block_tensor_names=[
  stable_diffusion_loader.ResidualBlockTensorNames(
- norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
- conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
+ norm_1=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0"
+ ),
+ conv_1=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2"
+ ),
  norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
  conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
  time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
@@ -128,9 +139,15 @@ _up_decoder_blocks_tensor_names = [
  ],
  transformer_block_tensor_names=[
  stable_diffusion_loader.TransformerBlockTensorNames(
- pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
- conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
- conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
+ pre_conv_norm=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.1.norm"
+ ),
+ conv_in=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in"
+ ),
+ conv_out=(
+ f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out"
+ ),
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
  norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
  q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
@@ -157,7 +174,9 @@ _up_decoder_blocks_tensor_names = [
  else None,
  upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
  if 0 < i < 3
- else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
+ else (
+ f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None
+ ),
  )
  for i in range(4)
  ]
@@ -475,7 +494,10 @@ class Diffusion(nn.Module):
  layers_cfg.ActivationConfig(config.final_activation_type)
  )
  self.conv_out = nn.Conv2d(
- reversed_block_out_channels[-1], config.out_channels, kernel_size=3, padding=1
+ reversed_block_out_channels[-1],
+ config.out_channels,
+ kernel_size=3,
+ padding=1,
  )

  @torch.inference_mode
@@ -496,12 +518,15 @@
  x = self.conv_in(latents)
  skip_connection_tensors = [x]
  for encoder in self.down_encoders:
- x, hidden_states = encoder(x, time_emb, context, output_hidden_states=True)
+ x, hidden_states = encoder(
+ x, time_emb, context, output_hidden_states=True
+ )
  skip_connection_tensors.extend(hidden_states)
  x = self.mid_block(x, time_emb, context)
  for decoder in self.up_decoders:
  encoder_tensors = [
- skip_connection_tensors.pop() for i in range(self.config.layers_per_block + 1)
+ skip_connection_tensors.pop()
+ for i in range(self.config.layers_per_block + 1)
  ]
  x = decoder(x, encoder_tensors, time_emb, context)
  x = self.final_norm(x)

ai_edge_torch/generative/examples/stable_diffusion/encoder.py

@@ -13,12 +13,11 @@
  # limitations under the License.
  # ==============================================================================

+ from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA
  import torch
  from torch import nn
  from torch.nn import functional as F

- from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention # NOQA
-

  class AttentionBlock(nn.Module):

@@ -50,7 +49,9 @@ class ResidualBlock(nn.Module):
  self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

  self.groupnorm_2 = nn.GroupNorm(32, out_channels)
- self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+ self.conv_2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, padding=1
+ )

  if in_channels == out_channels:
  self.residual_layer = nn.Identity()

ai_edge_torch/generative/examples/stable_diffusion/pipeline.py

@@ -18,30 +18,41 @@ import os
  from pathlib import Path
  from typing import Dict, Optional

- import numpy as np
- from PIL import Image
- from tqdm import tqdm
-
  import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
  from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer # NOQA
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
  from ai_edge_torch.model import TfLiteModel
+ import numpy as np
+ from PIL import Image
+ from tqdm import tqdm

  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
  '--tokenizer_vocab_dir',
  type=str,
- help='Directory to the tokenizer vocabulary files, which include `merges.txt` and `vocab.json`',
+ help=(
+ 'Directory to the tokenizer vocabulary files, which include'
+ ' `merges.txt` and `vocab.json`'
+ ),
  required=True,
  )
  arg_parser.add_argument(
- '--clip_ckpt', type=str, help='Path to CLIP TFLite tflite file', required=True
+ '--clip_ckpt',
+ type=str,
+ help='Path to CLIP TFLite tflite file',
+ required=True,
  )
  arg_parser.add_argument(
- '--diffusion_ckpt', type=str, help='Path to diffusion tflite file', required=True
+ '--diffusion_ckpt',
+ type=str,
+ help='Path to diffusion tflite file',
+ required=True,
  )
  arg_parser.add_argument(
- '--decoder_ckpt', type=str, help='Path to decoder tflite file', required=True
+ '--decoder_ckpt',
+ type=str,
+ help='Path to decoder tflite file',
+ required=True,
  )
  arg_parser.add_argument(
  '--output_path',
@@ -56,20 +67,29 @@ arg_parser.add_argument(
  help='The prompt to guide the image generation.',
  )
  arg_parser.add_argument(
- '--n_inference_steps', default=20, type=int, help='The number of denoising steps.'
+ '--n_inference_steps',
+ default=20,
+ type=int,
+ help='The number of denoising steps.',
  )
  arg_parser.add_argument(
  '--sampler',
  default='k_euler',
  type=str,
  choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
- help='A sampler to be used to denoise the encoded image latents. Can be one of `k_lms, `k_euler`, or `k_euler_ancestral`.',
+ help=(
+ 'A sampler to be used to denoise the encoded image latents. Can be one'
+ ' of `k_lms, `k_euler`, or `k_euler_ancestral`.'
+ ),
  )
  arg_parser.add_argument(
  '--seed',
  default=None,
  type=int,
- help='A seed to make generation deterministic. A random number is used if unspecified.',
+ help=(
+ 'A seed to make generation deterministic. A random number is used if'
+ ' unspecified.'
+ ),
  )


@@ -154,7 +174,9 @@ def run_tflite_pipeline(
  elif sampler == 'k_euler':
  sampler = samplers.KEulerSampler(n_inference_steps=n_inference_steps)
  elif sampler == 'k_euler_ancestral':
- sampler = samplers.KEulerAncestralSampler(n_inference_steps=n_inference_steps)
+ sampler = samplers.KEulerAncestralSampler(
+ n_inference_steps=n_inference_steps
+ )
  else:
  raise ValueError(
  'Unknown sampler value %s. '
@@ -173,7 +195,8 @@
  if input_image:
  if not hasattr(model, 'encoder'):
  raise AttributeError(
- 'Stable Diffusion must be initialized with encoder to accept input_image.'
+ 'Stable Diffusion must be initialized with encoder to accept'
+ ' input_image.'
  )
  input_image = input_image.resize((width, height))
  input_image_np = np.array(input_image).astype(np.float32)

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py

@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

- import numpy as np
-
  from ai_edge_torch.generative.examples.stable_diffusion import util
  from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface # NOQA
+ import numpy as np


  class KEulerSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerSampler(SamplerInterface):

  def set_strength(self, strength=1):
  start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
- self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+ self.timesteps = np.linspace(
+ self.n_training_steps - 1, 0, self.n_inference_steps
+ )
  self.timesteps = self.timesteps[start_step:]
  self.initial_scale = self.sigmas[start_step]
  self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py

@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

- import numpy as np
-
  from ai_edge_torch.generative.examples.stable_diffusion import util
  from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface # NOQA
+ import numpy as np


  class KEulerAncestralSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerAncestralSampler(SamplerInterface):

  def set_strength(self, strength=1):
  start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
- self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+ self.timesteps = np.linspace(
+ self.n_training_steps - 1, 0, self.n_inference_steps
+ )
  self.timesteps = self.timesteps[start_step:]
  self.initial_scale = self.sigmas[start_step]
  self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py

@@ -13,10 +13,9 @@
  # limitations under the License.
  # ==============================================================================

- import numpy as np
-
  from ai_edge_torch.generative.examples.stable_diffusion import util
  from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface # NOQA
+ import numpy as np


  class KLMSSampler(SamplerInterface):
@@ -48,7 +47,9 @@ class KLMSSampler(SamplerInterface):

  def set_strength(self, strength=1):
  start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
- self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+ self.timesteps = np.linspace(
+ self.n_training_steps - 1, 0, self.n_inference_steps
+ )
  self.timesteps = self.timesteps[start_step:]
  self.initial_scale = self.sigmas[start_step]
  self.step_count = start_step
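
All three samplers above share the same set_strength() arithmetic; the standalone sketch below simply replays it with illustrative numbers (n_inference_steps=20 matches the pipeline's --n_inference_steps default, strength=0.75 and n_training_steps=1000 are illustrative) to show how strength selects the starting step.

    import numpy as np

    n_inference_steps = 20   # pipeline default for --n_inference_steps
    n_training_steps = 1000  # illustrative; matches get_alphas_cumprod's default
    strength = 0.75          # hypothetical img2img strength

    # Higher strength -> smaller start_step -> more denoising steps are kept.
    start_step = n_inference_steps - int(n_inference_steps * strength)  # 20 - 15 = 5
    timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps)
    timesteps = timesteps[start_step:]

    print(start_step)       # 5
    print(timesteps.size)   # 15 of the 20 inference timesteps remain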

ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py

@@ -27,7 +27,10 @@ def create_bytes_table() -> dict:
  special_count = 0
  for byte in range(256):
  category = unicodedata.category(chr(byte))
- if category[0] not in ['C', 'Z']: # ith character is NOT control char or space
+ if category[0] not in [
+ 'C',
+ 'Z',
+ ]: # ith character is NOT control char or space
  table[byte] = chr(byte)
  else: # ith character IS control char or space
  table[byte] = chr(special_count + 256)
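
The hunk above only rewraps the membership test; for readers who want to see the full mapping it builds, here is a hedged, self-contained restatement of create_bytes_table() (the table initialization, the special_count increment, and the return fall outside the hunk and are assumed here).

    import unicodedata

    def create_bytes_table() -> dict:
      table = {}          # assumed: initialized just above the hunk
      special_count = 0
      for byte in range(256):
        category = unicodedata.category(chr(byte))
        if category[0] not in ['C', 'Z']:  # printable, non-space byte: keep as is
          table[byte] = chr(byte)
        else:  # control char or space: remap to a codepoint above 255
          table[byte] = chr(special_count + 256)
          special_count += 1  # assumed: falls below the hunk
      return table

    table = create_bytes_table()
    print(table[65])  # 'A' maps to itself
    print(table[32])  # space maps to 'Ġ' (U+0120)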

ai_edge_torch/generative/examples/stable_diffusion/util.py

@@ -20,14 +20,20 @@ import torch


  def get_time_embedding(timestep):
- freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+ freqs = torch.pow(
+ 10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160
+ )
  x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
  return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)


- def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+ def get_alphas_cumprod(
+ beta_start=0.00085, beta_end=0.0120, n_training_steps=1000
+ ):
  betas = (
- np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32)
+ np.linspace(
+ beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32
+ )
  ** 2
  )
  alphas = 1.0 - betas
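
Both helpers touched above are short enough to restate whole; this hedged sketch mirrors the lines visible in the hunk and assumes the final cumulative-product return, which falls below the hunk.

    import numpy as np
    import torch

    def get_time_embedding(timestep):
      # 160 geometric frequencies, concatenated cos/sin halves -> a (1, 320) embedding.
      freqs = torch.pow(
          10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160
      )
      x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
      return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

    def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
      # Squared-linear beta schedule, as in the hunk above.
      betas = (
          np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32)
          ** 2
      )
      alphas = 1.0 - betas
      return np.cumprod(alphas, axis=0)  # assumed: falls below the hunk

    print(get_time_embedding(10).shape)  # torch.Size([1, 320])
    print(get_alphas_cumprod().shape)    # (1000,)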

ai_edge_torch/generative/examples/t5/convert_to_tflite.py

@@ -16,12 +16,11 @@
  import os
  from pathlib import Path

- import numpy as np
- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.t5 import t5
  from ai_edge_torch.generative.quantize import quant_recipes
+ import numpy as np
+ import torch


  # TODO(haoliang): clean this up untile 2-sig model is validated e2e.
@@ -73,8 +72,12 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
  embedding_layer = torch.nn.Embedding(
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
- t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
- t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+ t5_encoder_model = t5.build_t5_encoder_model(
+ config, embedding_layer, checkpoint_path
+ )
+ t5_decoder_model = t5.build_t5_decoder_model(
+ config, embedding_layer, checkpoint_path
+ )

  # encoder
  seq_len = 512