ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240802__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +45 -36
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -47
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/RECORD +89 -89
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -18,19 +18,22 @@ import os
 from pathlib import Path
 from typing import Optional
 
-import torch
-
 import ai_edge_torch
 import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
 import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
 from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.quantize import quant_recipes
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
 
 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
-    '--clip_ckpt',
+    '--clip_ckpt',
+    type=str,
+    help='Path to source CLIP model checkpoint',
+    required=True,
 )
 arg_parser.add_argument(
     '--diffusion_ckpt',
@@ -60,6 +63,7 @@ def convert_stable_diffusion_to_tflite(
     decoder_ckpt_path: str,
     image_height: int = 512,
     image_width: int = 512,
+    quantize: bool = True,
 ):
 
   clip_model = clip.CLIP(clip.get_model_config())
@@ -91,9 +95,13 @@ def convert_stable_diffusion_to_tflite(
   timestamp = 0
   len_prompt = 1
   prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
-  input_image = torch.full(
+  input_image = torch.full(
+      (1, 3, image_height, image_width), 0, dtype=torch.float32
+  )
   noise = torch.full(
-      (len_prompt, 4, image_height // 8, image_width // 8),
+      (len_prompt, 4, image_height // 8, image_width // 8),
+      0,
+      dtype=torch.float32,
   )
 
   input_latents = torch.zeros_like(noise)
@@ -105,15 +113,19 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
+  quant_config = (
+      quant_recipes.full_int8_weight_only_recipe() if quantize else None
+  )
+
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
-  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
-
-  )
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/clip.tflite')
 
   # TODO(yichunk): enable image encoder conversion
   # Image encoder
-  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert(quant_config=quant_config).export(
   #     f'{output_dir}/encoder.tflite'
   # )
 
@@ -122,12 +134,12 @@ def convert_stable_diffusion_to_tflite(
       'diffusion',
       diffusion_model,
       (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
-  ).convert().export(f'{output_dir}/diffusion.tflite')
+  ).convert(quant_config=quant_config).export(f'{output_dir}/diffusion.tflite')
 
   # Image decoder
-  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
-
-  )
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
+      quant_config=quant_config
+  ).export(f'{output_dir}/decoder.tflite')
 
 
 if __name__ == '__main__':
@@ -139,4 +151,5 @@ if __name__ == '__main__':
       decoder_ckpt_path=args.decoder_ckpt,
       image_height=512,
       image_width=512,
+      quantize=True,
   )
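The substantive change in this file is the new quantize flag: a single quant_config is computed once and passed to every convert() call. Below is a minimal sketch of that pattern, assuming decoder_model and input_latents are built the same way as in the script above; the function wrapper and output directory are illustrative, not part of the package API.

import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes


def convert_decoder(decoder_model, input_latents, output_dir, quantize=True):
  # quantize=False falls back to a plain float32 conversion (quant_config=None).
  quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert(
      quant_config=quant_config
  ).export(f'{output_dir}/decoder.tflite')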
ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
+from torch import nn
 
 TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
     post_quant_conv="first_stage_model.post_quant_conv",
@@ -104,7 +103,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
            norm_2="first_stage_model.decoder.up.1.block.0.norm2",
            conv_1="first_stage_model.decoder.up.1.block.0.conv1",
            conv_2="first_stage_model.decoder.up.1.block.0.conv2",
-           residual_layer=
+           residual_layer=(
+               "first_stage_model.decoder.up.1.block.0.nin_shortcut"
+           ),
        ),
        stable_diffusion_loader.ResidualBlockTensorNames(
            norm_1="first_stage_model.decoder.up.1.block.1.norm1",
@@ -128,7 +129,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
            norm_2="first_stage_model.decoder.up.0.block.0.norm2",
            conv_1="first_stage_model.decoder.up.0.block.0.conv1",
            conv_2="first_stage_model.decoder.up.0.block.0.conv2",
-           residual_layer=
+           residual_layer=(
+               "first_stage_model.decoder.up.0.block.0.nin_shortcut"
+           ),
        ),
        stable_diffusion_loader.ResidualBlockTensorNames(
            norm_1="first_stage_model.decoder.up.0.block.1.norm1",
@@ -293,12 +296,15 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
          qkv_fused_interleaved=False,
          rotary_percentage=0.0,
      ),
+     enable_hlfb=False,
  )
 
  mid_block_config = unet_cfg.MidBlock2DConfig(
      in_channels=block_out_channels[-1],
      normalization_config=norm_config,
-     activation_config=layers_cfg.ActivationConfig(
+     activation_config=layers_cfg.ActivationConfig(
+         layers_cfg.ActivationType.SILU
+     ),
      num_layers=1,
      attention_block_config=att_config,
  )
@@ -307,7 +313,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
      in_channels=in_channels,
      latent_channels=latent_channels,
      out_channels=out_channels,
-     activation_config=layers_cfg.ActivationConfig(
+     activation_config=layers_cfg.ActivationConfig(
+         layers_cfg.ActivationType.SILU
+     ),
      block_out_channels=block_out_channels,
      scaling_factor=scaling_factor,
      layers_per_block=layers_per_block,
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
+from torch import nn
 
 _down_encoder_blocks_tensor_names = [
     stable_diffusion_loader.DownEncoderBlockTensorNames(
@@ -39,9 +38,15 @@ _down_encoder_blocks_tensor_names = [
         ],
         transformer_block_tensor_names=[
             stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=
-
-
+                pre_conv_norm=(
+                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm"
+                ),
+                conv_in=(
+                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in"
+                ),
+                conv_out=(
+                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out"
+                ),
                 self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
                     norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
                     q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
@@ -80,7 +85,9 @@ _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
            conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
            norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
            conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
-           time_embedding=
+           time_embedding=(
+               f"model.diffusion_model.middle_block.{i}.emb_layers.1"
+           ),
        )
        for i in [0, 2]
    ],
@@ -117,8 +124,12 @@ _up_decoder_blocks_tensor_names = [
     stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
         residual_block_tensor_names=[
             stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1=
-
+                norm_1=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0"
+                ),
+                conv_1=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2"
+                ),
                 norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
                 conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
                 time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
@@ -128,9 +139,15 @@ _up_decoder_blocks_tensor_names = [
         ],
         transformer_block_tensor_names=[
             stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=
-
-
+                pre_conv_norm=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.1.norm"
+                ),
+                conv_in=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in"
+                ),
+                conv_out=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out"
+                ),
                 self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
                     norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
                     q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
@@ -157,7 +174,9 @@ _up_decoder_blocks_tensor_names = [
         else None,
         upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
         if 0 < i < 3
-        else (
+        else (
+            f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None
+        ),
     )
     for i in range(4)
 ]
@@ -294,6 +313,7 @@ class Diffusion(nn.Module):
                attention_batch_size=config.transformer_batch_size,
                normalization_config=config.transformer_norm_config,
                attention_config=attention_config,
+               enable_hlfb=False,
            ),
            cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                query_dim=output_channel,
@@ -301,6 +321,7 @@ class Diffusion(nn.Module):
                attention_batch_size=config.transformer_batch_size,
                normalization_config=config.transformer_norm_config,
                attention_config=attention_config,
+               enable_hlfb=False,
            ),
            pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
            feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -354,6 +375,7 @@ class Diffusion(nn.Module):
                attention_batch_size=config.transformer_batch_size,
                normalization_config=config.transformer_norm_config,
                attention_config=attention_config,
+               enable_hlfb=False,
            ),
            cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                query_dim=mid_block_channels,
@@ -361,6 +383,7 @@ class Diffusion(nn.Module):
                attention_batch_size=config.transformer_batch_size,
                normalization_config=config.transformer_norm_config,
                attention_config=attention_config,
+               enable_hlfb=False,
            ),
            pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
            feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -415,6 +438,7 @@ class Diffusion(nn.Module):
                attention_batch_size=config.transformer_batch_size,
                normalization_config=config.transformer_norm_config,
                attention_config=attention_config,
+               enable_hlfb=False,
            ),
            cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                query_dim=output_channel,
@@ -422,6 +446,7 @@ class Diffusion(nn.Module):
                attention_batch_size=config.transformer_batch_size,
                normalization_config=config.transformer_norm_config,
                attention_config=attention_config,
+               enable_hlfb=False,
            ),
            pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
            feed_forward_block_config=unet_cfg.FeedForwardBlock2DConfig(
@@ -469,7 +494,10 @@ class Diffusion(nn.Module):
         layers_cfg.ActivationConfig(config.final_activation_type)
     )
     self.conv_out = nn.Conv2d(
-        reversed_block_out_channels[-1],
+        reversed_block_out_channels[-1],
+        config.out_channels,
+        kernel_size=3,
+        padding=1,
     )
 
   @torch.inference_mode
@@ -490,12 +518,15 @@ class Diffusion(nn.Module):
     x = self.conv_in(latents)
     skip_connection_tensors = [x]
     for encoder in self.down_encoders:
-      x, hidden_states = encoder(
+      x, hidden_states = encoder(
+          x, time_emb, context, output_hidden_states=True
+      )
       skip_connection_tensors.extend(hidden_states)
     x = self.mid_block(x, time_emb, context)
     for decoder in self.up_decoders:
       encoder_tensors = [
-          skip_connection_tensors.pop()
+          skip_connection_tensors.pop()
+          for i in range(self.config.layers_per_block + 1)
      ]
      x = decoder(x, encoder_tensors, time_emb, context)
     x = self.final_norm(x)
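The reflowed forward pass makes the skip-connection contract easier to read: down encoders push their hidden states onto a stack, and each up decoder pops layers_per_block + 1 of them back off in LIFO order. A toy sketch of that bookkeeping, where strings stand in for tensors and the push counts are assumed for the example rather than taken from the real block layout:

layers_per_block = 2
stack = []
for block in range(4):  # down encoders push their hidden states
  stack.extend(f'down{block}_h{i}' for i in range(layers_per_block + 1))
for block in range(4):  # up decoders pop layers_per_block + 1 tensors each
  encoder_tensors = [stack.pop() for _ in range(layers_per_block + 1)]
  print(block, encoder_tensors)  # most recently pushed states come out first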
ai_edge_torch/generative/examples/stable_diffusion/encoder.py

@@ -13,12 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
 import torch
 from torch import nn
 from torch.nn import functional as F
 
-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
-
 
 class AttentionBlock(nn.Module):
 
@@ -50,7 +49,9 @@ class ResidualBlock(nn.Module):
     self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
 
     self.groupnorm_2 = nn.GroupNorm(32, out_channels)
-    self.conv_2 = nn.Conv2d(
+    self.conv_2 = nn.Conv2d(
+        out_channels, out_channels, kernel_size=3, padding=1
+    )
 
     if in_channels == out_channels:
       self.residual_layer = nn.Identity()
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py

@@ -18,30 +18,41 @@ import os
 from pathlib import Path
 from typing import Dict, Optional
 
-import numpy as np
-from PIL import Image
-from tqdm import tqdm
-
 import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
 from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer  # NOQA
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
 from ai_edge_torch.model import TfLiteModel
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
 
 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
     '--tokenizer_vocab_dir',
     type=str,
-    help=
+    help=(
+        'Directory to the tokenizer vocabulary files, which include'
+        ' `merges.txt` and `vocab.json`'
+    ),
     required=True,
 )
 arg_parser.add_argument(
-    '--clip_ckpt',
+    '--clip_ckpt',
+    type=str,
+    help='Path to CLIP TFLite tflite file',
+    required=True,
 )
 arg_parser.add_argument(
-    '--diffusion_ckpt',
+    '--diffusion_ckpt',
+    type=str,
+    help='Path to diffusion tflite file',
+    required=True,
 )
 arg_parser.add_argument(
-    '--decoder_ckpt',
+    '--decoder_ckpt',
+    type=str,
+    help='Path to decoder tflite file',
+    required=True,
 )
 arg_parser.add_argument(
     '--output_path',
@@ -56,14 +67,29 @@ arg_parser.add_argument(
     help='The prompt to guide the image generation.',
 )
 arg_parser.add_argument(
-    '--n_inference_steps',
+    '--n_inference_steps',
+    default=20,
+    type=int,
+    help='The number of denoising steps.',
 )
 arg_parser.add_argument(
     '--sampler',
     default='k_euler',
     type=str,
     choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
-    help=
+    help=(
+        'A sampler to be used to denoise the encoded image latents. Can be one'
+        ' of `k_lms, `k_euler`, or `k_euler_ancestral`.'
+    ),
+)
+arg_parser.add_argument(
+    '--seed',
+    default=None,
+    type=int,
+    help=(
+        'A seed to make generation deterministic. A random number is used if'
+        ' unspecified.'
+    ),
 )
 
 
@@ -148,7 +174,9 @@ def run_tflite_pipeline(
   elif sampler == 'k_euler':
     sampler = samplers.KEulerSampler(n_inference_steps=n_inference_steps)
   elif sampler == 'k_euler_ancestral':
-    sampler = samplers.KEulerAncestralSampler(
+    sampler = samplers.KEulerAncestralSampler(
+        n_inference_steps=n_inference_steps
+    )
   else:
     raise ValueError(
         'Unknown sampler value %s. '
@@ -167,7 +195,8 @@ def run_tflite_pipeline(
   if input_image:
     if not hasattr(model, 'encoder'):
       raise AttributeError(
-          'Stable Diffusion must be initialized with encoder to accept input_image.'
+          'Stable Diffusion must be initialized with encoder to accept'
+          ' input_image.'
       )
     input_image = input_image.resize((width, height))
     input_image_np = np.array(input_image).astype(np.float32)
@@ -219,4 +248,5 @@ if __name__ == '__main__':
       output_path=args.output_path,
       sampler=args.sampler,
       n_inference_steps=args.n_inference_steps,
+      seed=args.seed,
   )
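The new --seed flag is threaded through to run_tflite_pipeline(..., seed=args.seed) in the __main__ block. The diff does not show how the seed is consumed inside the pipeline; the following is a hedged sketch of the usual pattern, seeding a NumPy generator before drawing the initial latent noise (make_noise and its shape are illustrative, not the package's API):

import numpy as np


def make_noise(shape, seed=None):
  # seed=None matches the flag's default: a fresh random draw per run.
  rng = np.random.default_rng(seed)
  return rng.standard_normal(shape, dtype=np.float32)


noise_a = make_noise((1, 4, 64, 64), seed=1234)
noise_b = make_noise((1, 4, 64, 64), seed=1234)
assert (noise_a == noise_b).all()  # identical seeds reproduce the latents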
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
 from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface  # NOQA
+import numpy as np
 
 
 class KEulerSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerSampler(SamplerInterface):
 
   def set_strength(self, strength=1):
     start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
-    self.timesteps = np.linspace(
+    self.timesteps = np.linspace(
+        self.n_training_steps - 1, 0, self.n_inference_steps
+    )
     self.timesteps = self.timesteps[start_step:]
     self.initial_scale = self.sigmas[start_step]
     self.step_count = start_step
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
 from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface  # NOQA
+import numpy as np
 
 
 class KEulerAncestralSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerAncestralSampler(SamplerInterface):
 
   def set_strength(self, strength=1):
     start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
-    self.timesteps = np.linspace(
+    self.timesteps = np.linspace(
+        self.n_training_steps - 1, 0, self.n_inference_steps
+    )
     self.timesteps = self.timesteps[start_step:]
     self.initial_scale = self.sigmas[start_step]
     self.step_count = start_step
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
 from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface  # NOQA
+import numpy as np
 
 
 class KLMSSampler(SamplerInterface):
@@ -48,7 +47,9 @@ class KLMSSampler(SamplerInterface):
 
   def set_strength(self, strength=1):
     start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
-    self.timesteps = np.linspace(
+    self.timesteps = np.linspace(
+        self.n_training_steps - 1, 0, self.n_inference_steps
+    )
     self.timesteps = self.timesteps[start_step:]
     self.initial_scale = self.sigmas[start_step]
     self.step_count = start_step
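The same set_strength reflow appears in all three samplers above (k_euler, k_euler_ancestral, k_lms). A worked check of what it computes, using example values in place of the sampler attributes (self.n_training_steps and self.n_inference_steps):

import numpy as np

n_training_steps, n_inference_steps, strength = 1000, 20, 0.8
start_step = n_inference_steps - int(n_inference_steps * strength)
timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps)[start_step:]
print(start_step, round(timesteps[0], 2), timesteps[-1])  # 4 788.68 0.0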
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py

@@ -27,7 +27,10 @@ def create_bytes_table() -> dict:
   special_count = 0
   for byte in range(256):
     category = unicodedata.category(chr(byte))
-    if category[0] not in [
+    if category[0] not in [
+        'C',
+        'Z',
+    ]:  # ith character is NOT control char or space
       table[byte] = chr(byte)
     else:  # ith character IS control char or space
       table[byte] = chr(special_count + 256)
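For reference, unicodedata.category returns a two-letter code whose first letter is the major class, so the reflowed condition keeps any byte that is neither a control character ('C') nor a separator ('Z'):

import unicodedata

print(unicodedata.category('A'))     # 'Lu' -> kept as-is in the table
print(unicodedata.category(' '))     # 'Zs' -> remapped past 256
print(unicodedata.category(chr(0)))  # 'Cc' -> remapped past 256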
ai_edge_torch/generative/examples/stable_diffusion/util.py

@@ -20,14 +20,20 @@ import torch
 
 
 def get_time_embedding(timestep):
-  freqs = torch.pow(
+  freqs = torch.pow(
+      10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160
+  )
   x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
   return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
 
 
-def get_alphas_cumprod(
+def get_alphas_cumprod(
+    beta_start=0.00085, beta_end=0.0120, n_training_steps=1000
+):
   betas = (
-      np.linspace(
+      np.linspace(
+          beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32
+      )
       ** 2
   )
   alphas = 1.0 - betas
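A quick shape check of the reflowed get_time_embedding: 160 inverse-frequency terms, with the cosine and sine halves concatenated into a (1, 320) vector per scalar timestep (timestep value 50 is arbitrary):

import torch

freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
x = torch.tensor([50], dtype=torch.float32)[:, None] * freqs[None]
embedding = torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
print(freqs.shape, x.shape, embedding.shape)
# torch.Size([160]) torch.Size([1, 160]) torch.Size([1, 320])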
ai_edge_torch/generative/examples/t5/convert_to_tflite.py

@@ -16,12 +16,11 @@
 import os
 from pathlib import Path
 
-import numpy as np
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.t5 import t5
 from ai_edge_torch.generative.quantize import quant_recipes
+import numpy as np
+import torch
 
 
 # TODO(haoliang): clean this up untile 2-sig model is validated e2e.
@@ -73,8 +72,12 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
   embedding_layer = torch.nn.Embedding(
       config.vocab_size, config.embedding_dim, padding_idx=0
   )
-  t5_encoder_model = t5.build_t5_encoder_model(
-
+  t5_encoder_model = t5.build_t5_encoder_model(
+      config, embedding_layer, checkpoint_path
+  )
+  t5_decoder_model = t5.build_t5_decoder_model(
+      config, embedding_layer, checkpoint_path
+  )
 
   # encoder
   seq_len = 512