ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/__init__.py +1 -0
- ai_edge_torch/convert/conversion.py +12 -8
- ai_edge_torch/convert/conversion_utils.py +38 -20
- ai_edge_torch/convert/converter.py +11 -5
- ai_edge_torch/convert/fx_passes/__init__.py +3 -4
- ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
- ai_edge_torch/convert/test/test_convert.py +39 -16
- ai_edge_torch/convert/test/test_convert_composites.py +115 -86
- ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
- ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
- ai_edge_torch/convert/to_channel_last_io.py +6 -2
- ai_edge_torch/debug/culprit.py +41 -16
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +4 -3
- ai_edge_torch/debug/utils.py +3 -1
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
- ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
- ai_edge_torch/generative/examples/t5/t5.py +158 -125
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
- ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
- ai_edge_torch/generative/fx_passes/__init__.py +1 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
- ai_edge_torch/generative/layers/attention.py +19 -11
- ai_edge_torch/generative/layers/builder.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +4 -3
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
- ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
- ai_edge_torch/generative/quantize/example.py +2 -3
- ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
- ai_edge_torch/generative/test/loader_test.py +5 -4
- ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
- ai_edge_torch/generative/test/test_model_conversion.py +2 -3
- ai_edge_torch/generative/test/test_quantize.py +45 -48
- ai_edge_torch/generative/utilities/loader.py +55 -28
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
- ai_edge_torch/generative/utilities/t5_loader.py +77 -48
- ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
- ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
- ai_edge_torch/model.py +8 -5
- ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
- ai_edge_torch/quantize/quant_config.py +6 -2
- ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py

@@ -18,8 +18,6 @@ import os
 from pathlib import Path
 from typing import Optional
 
-import torch
-
 import ai_edge_torch
 import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
 import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
@@ -28,10 +26,14 @@ from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
 from ai_edge_torch.generative.quantize import quant_recipes
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
 
 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
-    '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+    '--clip_ckpt',
+    type=str,
+    help='Path to source CLIP model checkpoint',
+    required=True,
 )
 arg_parser.add_argument(
     '--diffusion_ckpt',
@@ -93,9 +95,13 @@ def convert_stable_diffusion_to_tflite(
   timestamp = 0
   len_prompt = 1
   prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
-  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+  input_image = torch.full(
+      (1, 3, image_height, image_width), 0, dtype=torch.float32
+  )
   noise = torch.full(
-      (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+      (len_prompt, 4, image_height // 8, image_width // 8),
+      0,
+      dtype=torch.float32,
   )
 
   input_latents = torch.zeros_like(noise)
@@ -107,7 +113,9 @@ def convert_stable_diffusion_to_tflite(
   if not os.path.exists(output_dir):
     Path(output_dir).mkdir(parents=True, exist_ok=True)
 
-  quant_config = quant_recipes.full_int8_weight_only_recipe() if quantize else None
+  quant_config = (
+      quant_recipes.full_int8_weight_only_recipe() if quantize else None
+  )
 
   # TODO(yichunk): convert to multi signature tflite model.
   # CLIP text encoder
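
Most of the churn in this file is formatter reflow, but the `quant_config` hunk is worth a note: the conversion path takes an optional quantization recipe, and `None` means plain float conversion. A minimal sketch of how such a recipe is passed through `ai_edge_torch.convert` (the stand-in module and output path below are assumptions for illustration, not from this package):

import torch
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes

model = torch.nn.Linear(4, 4).eval()  # stand-in for e.g. the CLIP encoder
sample_args = (torch.randn(1, 4),)

quantize = True
quant_config = (
    quant_recipes.full_int8_weight_only_recipe() if quantize else None
)
edge_model = ai_edge_torch.convert(model, sample_args, quant_config=quant_config)
edge_model.export('/tmp/model.tflite')  # one .tflite per submodule, as in this script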

ai_edge_torch/generative/examples/stable_diffusion/decoder.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
+from torch import nn
 
 TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
     post_quant_conv="first_stage_model.post_quant_conv",
@@ -104,7 +103,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
         norm_2="first_stage_model.decoder.up.1.block.0.norm2",
         conv_1="first_stage_model.decoder.up.1.block.0.conv1",
         conv_2="first_stage_model.decoder.up.1.block.0.conv2",
-        residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
+        residual_layer=(
+            "first_stage_model.decoder.up.1.block.0.nin_shortcut"
+        ),
     ),
     stable_diffusion_loader.ResidualBlockTensorNames(
         norm_1="first_stage_model.decoder.up.1.block.1.norm1",
@@ -128,7 +129,9 @@ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
         norm_2="first_stage_model.decoder.up.0.block.0.norm2",
         conv_1="first_stage_model.decoder.up.0.block.0.conv1",
         conv_2="first_stage_model.decoder.up.0.block.0.conv2",
-        residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
+        residual_layer=(
+            "first_stage_model.decoder.up.0.block.0.nin_shortcut"
+        ),
     ),
     stable_diffusion_loader.ResidualBlockTensorNames(
         norm_1="first_stage_model.decoder.up.0.block.1.norm1",
@@ -299,7 +302,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
   mid_block_config = unet_cfg.MidBlock2DConfig(
       in_channels=block_out_channels[-1],
      normalization_config=norm_config,
-      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
       num_layers=1,
       attention_block_config=att_config,
   )
@@ -308,7 +313,9 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
       in_channels=in_channels,
       latent_channels=latent_channels,
       out_channels=out_channels,
-      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
       block_out_channels=block_out_channels,
       scaling_factor=scaling_factor,
       layers_per_block=layers_per_block,
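
The `TensorNames` blocks above only declare a mapping from the decoder's module attributes to tensor keys in the original Stable Diffusion checkpoint; the loader uses that mapping to copy weights over. A toy illustration of the idea (the `remap` helper is hypothetical, not the actual `AutoEncoderModelLoader` API):

import torch

def remap(checkpoint: dict, name_map: dict) -> dict:
  # Copy checkpoint tensors into the names the target model expects.
  return {tgt: checkpoint[src] for tgt, src in name_map.items() if src in checkpoint}

checkpoint = {
    'first_stage_model.decoder.up.1.block.0.nin_shortcut.weight': torch.zeros(4, 8, 1, 1)
}
name_map = {
    'residual_layer.weight': 'first_stage_model.decoder.up.1.block.0.nin_shortcut.weight'
}
state = remap(checkpoint, name_map)  # {'residual_layer.weight': <4x8x1x1 tensor>}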

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 # ==============================================================================
 
-import torch
-from torch import nn
-
 import ai_edge_torch.generative.layers.builder as layers_builder
 import ai_edge_torch.generative.layers.model_config as layers_cfg
 import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+import torch
+from torch import nn
 
 _down_encoder_blocks_tensor_names = [
     stable_diffusion_loader.DownEncoderBlockTensorNames(
@@ -39,9 +38,15 @@ _down_encoder_blocks_tensor_names = [
         ],
         transformer_block_tensor_names=[
             stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
-                conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
-                conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
+                pre_conv_norm=(
+                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm"
+                ),
+                conv_in=(
+                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in"
+                ),
+                conv_out=(
+                    f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out"
+                ),
                 self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
                     norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
                     q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
@@ -80,7 +85,9 @@ _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
            conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
            norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
            conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
-            time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
+            time_embedding=(
+                f"model.diffusion_model.middle_block.{i}.emb_layers.1"
+            ),
        )
        for i in [0, 2]
    ],
@@ -117,8 +124,12 @@ _up_decoder_blocks_tensor_names = [
     stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
         residual_block_tensor_names=[
             stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
-                conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
+                norm_1=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0"
+                ),
+                conv_1=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2"
+                ),
                 norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
                 conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
                 time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
@@ -128,9 +139,15 @@ _up_decoder_blocks_tensor_names = [
         ],
         transformer_block_tensor_names=[
             stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
-                conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
-                conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
+                pre_conv_norm=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.1.norm"
+                ),
+                conv_in=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in"
+                ),
+                conv_out=(
+                    f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out"
+                ),
                 self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
                     norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
                     q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
@@ -157,7 +174,9 @@ _up_decoder_blocks_tensor_names = [
         else None,
         upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
         if 0 < i < 3
-        else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
+        else (
+            f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None
+        ),
     )
     for i in range(4)
 ]
@@ -475,7 +494,10 @@ class Diffusion(nn.Module):
         layers_cfg.ActivationConfig(config.final_activation_type)
     )
     self.conv_out = nn.Conv2d(
-        reversed_block_out_channels[-1], config.out_channels, kernel_size=3, padding=1
+        reversed_block_out_channels[-1],
+        config.out_channels,
+        kernel_size=3,
+        padding=1,
     )
 
   @torch.inference_mode
@@ -496,12 +518,15 @@ class Diffusion(nn.Module):
     x = self.conv_in(latents)
     skip_connection_tensors = [x]
     for encoder in self.down_encoders:
-      x, hidden_states = encoder(x, time_emb, context, output_hidden_states=True)
+      x, hidden_states = encoder(
+          x, time_emb, context, output_hidden_states=True
+      )
       skip_connection_tensors.extend(hidden_states)
     x = self.mid_block(x, time_emb, context)
     for decoder in self.up_decoders:
       encoder_tensors = [
-          skip_connection_tensors.pop() for i in range(self.config.layers_per_block + 1)
+          skip_connection_tensors.pop()
+          for i in range(self.config.layers_per_block + 1)
       ]
       x = decoder(x, encoder_tensors, time_emb, context)
     x = self.final_norm(x)
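
The forward-pass hunk above relies on a plain Python list as a LIFO stack for UNet skip connections: the `conv_in` output and every down-encoder's hidden states are pushed, and each up-decoder pops `layers_per_block + 1` of them, newest first. A self-contained sketch of that bookkeeping, assuming the standard SD 1.x layout (four down/up blocks, no downsampler on the last down block) with string labels standing in for tensors:

layers_per_block = 2
skips = ['x0']  # output of conv_in
for d in range(4):
  skips.extend(f'down{d}_res{k}' for k in range(layers_per_block))
  if d < 3:
    skips.append(f'down{d}_downsample')  # last down block has no downsampler
assert len(skips) == 12

for u in range(4):
  taken = [skips.pop() for _ in range(layers_per_block + 1)]
  print(f'up{u} consumes:', taken)  # most recent encoder outputs first
assert not skips  # every skip connection is consumed exactly once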

ai_edge_torch/generative/examples/stable_diffusion/encoder.py

@@ -13,12 +13,11 @@
 # limitations under the License.
 # ==============================================================================
 
+from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
 import torch
 from torch import nn
 from torch.nn import functional as F
 
-from ai_edge_torch.generative.examples.stable_diffusion.attention import SelfAttention  # NOQA
-
 
 class AttentionBlock(nn.Module):
 
@@ -50,7 +49,9 @@ class ResidualBlock(nn.Module):
     self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
 
     self.groupnorm_2 = nn.GroupNorm(32, out_channels)
-    self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
+    self.conv_2 = nn.Conv2d(
+        out_channels, out_channels, kernel_size=3, padding=1
+    )
 
     if in_channels == out_channels:
       self.residual_layer = nn.Identity()
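
The `ResidualBlock` being reflowed here follows the usual shortcut pattern: when input and output channel counts match, the skip branch is `nn.Identity()`; otherwise a 1x1 convolution projects the input to the right width. A minimal self-contained version (the structure mirrors the hunk; the forward body is inferred for illustration, not copied from the file):

import torch
from torch import nn
from torch.nn import functional as F

class ResidualBlock(nn.Module):

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.groupnorm_1 = nn.GroupNorm(32, in_channels)
    self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
    self.groupnorm_2 = nn.GroupNorm(32, out_channels)
    self.conv_2 = nn.Conv2d(
        out_channels, out_channels, kernel_size=3, padding=1
    )
    if in_channels == out_channels:
      self.residual_layer = nn.Identity()
    else:
      # 1x1 conv so the skip branch matches the main branch's channels.
      self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    residue = x
    x = self.conv_1(F.silu(self.groupnorm_1(x)))
    x = self.conv_2(F.silu(self.groupnorm_2(x)))
    return x + self.residual_layer(residue)

x = torch.randn(1, 128, 8, 8)
print(ResidualBlock(128, 256)(x).shape)  # torch.Size([1, 256, 8, 8])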

ai_edge_torch/generative/examples/stable_diffusion/pipeline.py

@@ -18,30 +18,41 @@ import os
 from pathlib import Path
 from typing import Dict, Optional
 
-import numpy as np
-from PIL import Image
-from tqdm import tqdm
-
 import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
 from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer  # NOQA
 import ai_edge_torch.generative.examples.stable_diffusion.util as util
 from ai_edge_torch.model import TfLiteModel
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
 
 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
     '--tokenizer_vocab_dir',
     type=str,
-    help='Directory to the tokenizer vocabulary files, which include `merges.txt` and `vocab.json`',
+    help=(
+        'Directory to the tokenizer vocabulary files, which include'
+        ' `merges.txt` and `vocab.json`'
+    ),
     required=True,
 )
 arg_parser.add_argument(
-    '--clip_ckpt', type=str, help='Path to CLIP TFLite tflite file', required=True
+    '--clip_ckpt',
+    type=str,
+    help='Path to CLIP TFLite tflite file',
+    required=True,
 )
 arg_parser.add_argument(
-    '--diffusion_ckpt', type=str, help='Path to diffusion tflite file', required=True
+    '--diffusion_ckpt',
+    type=str,
+    help='Path to diffusion tflite file',
+    required=True,
 )
 arg_parser.add_argument(
-    '--decoder_ckpt', type=str, help='Path to decoder tflite file', required=True
+    '--decoder_ckpt',
+    type=str,
+    help='Path to decoder tflite file',
+    required=True,
 )
 arg_parser.add_argument(
     '--output_path',
@@ -56,20 +67,29 @@ arg_parser.add_argument(
     help='The prompt to guide the image generation.',
 )
 arg_parser.add_argument(
-    '--n_inference_steps', default=20, type=int, help='The number of denoising steps.'
+    '--n_inference_steps',
+    default=20,
+    type=int,
+    help='The number of denoising steps.',
 )
 arg_parser.add_argument(
     '--sampler',
     default='k_euler',
     type=str,
     choices=['k_euler', 'k_euler_ancestral', 'k_lms'],
-    help='A sampler to be used to denoise the encoded image latents. Can be one of `k_lms, `k_euler`, or `k_euler_ancestral`.',
+    help=(
+        'A sampler to be used to denoise the encoded image latents. Can be one'
+        ' of `k_lms, `k_euler`, or `k_euler_ancestral`.'
+    ),
 )
 arg_parser.add_argument(
     '--seed',
     default=None,
     type=int,
-    help='A seed to make generation deterministic. A random number is used if unspecified.',
+    help=(
+        'A seed to make generation deterministic. A random number is used if'
+        ' unspecified.'
+    ),
 )
 
 
@@ -154,7 +174,9 @@ def run_tflite_pipeline(
   elif sampler == 'k_euler':
     sampler = samplers.KEulerSampler(n_inference_steps=n_inference_steps)
   elif sampler == 'k_euler_ancestral':
-    sampler = samplers.KEulerAncestralSampler(n_inference_steps=n_inference_steps)
+    sampler = samplers.KEulerAncestralSampler(
+        n_inference_steps=n_inference_steps
+    )
   else:
     raise ValueError(
         'Unknown sampler value %s. '
@@ -173,7 +195,8 @@ def run_tflite_pipeline(
   if input_image:
     if not hasattr(model, 'encoder'):
       raise AttributeError(
-          'Stable Diffusion must be initialized with encoder to accept input_image.'
+          'Stable Diffusion must be initialized with encoder to accept'
+          ' input_image.'
      )
     input_image = input_image.resize((width, height))
     input_image_np = np.array(input_image).astype(np.float32)
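
The sampler dispatch touched above is an if/elif chain over the same names listed in `choices`. A table-driven equivalent (a sketch of an alternative, not what pipeline.py actually does) keeps the valid names and the construction in one place; the three sampler classes all appear in this package's diff:

import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers

_SAMPLERS = {
    'k_euler': samplers.KEulerSampler,
    'k_euler_ancestral': samplers.KEulerAncestralSampler,
    'k_lms': samplers.KLMSSampler,
}

def make_sampler(name: str, n_inference_steps: int):
  if name not in _SAMPLERS:
    raise ValueError(f'Unknown sampler value {name}.')
  return _SAMPLERS[name](n_inference_steps=n_inference_steps)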

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
 from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface  # NOQA
+import numpy as np
 
 
 class KEulerSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerSampler(SamplerInterface):
 
   def set_strength(self, strength=1):
     start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
-    self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+    self.timesteps = np.linspace(
+        self.n_training_steps - 1, 0, self.n_inference_steps
+    )
     self.timesteps = self.timesteps[start_step:]
     self.initial_scale = self.sigmas[start_step]
     self.step_count = start_step
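
To make the `set_strength` arithmetic concrete: with the defaults visible in this diff (`n_training_steps=1000`, and say `n_inference_steps=20`), a strength of 0.75 gives `start_step = 20 - int(20 * 0.75) = 5`, so the first five timesteps are skipped and denoising starts partway down the schedule. A standalone check:

import numpy as np

n_training_steps = 1000
n_inference_steps = 20
strength = 0.75

start_step = n_inference_steps - int(n_inference_steps * strength)
timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps)[start_step:]
print(start_step, len(timesteps))  # 5 15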

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
 from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface  # NOQA
+import numpy as np
 
 
 class KEulerAncestralSampler(SamplerInterface):
@@ -46,7 +45,9 @@ class KEulerAncestralSampler(SamplerInterface):
 
   def set_strength(self, strength=1):
     start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
-    self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+    self.timesteps = np.linspace(
+        self.n_training_steps - 1, 0, self.n_inference_steps
+    )
     self.timesteps = self.timesteps[start_step:]
     self.initial_scale = self.sigmas[start_step]
     self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py

@@ -13,10 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 
-import numpy as np
-
 from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.examples.stable_diffusion.samplers.sampler import SamplerInterface  # NOQA
+import numpy as np
 
 
 class KLMSSampler(SamplerInterface):
@@ -48,7 +47,9 @@ class KLMSSampler(SamplerInterface):
 
   def set_strength(self, strength=1):
     start_step = self.n_inference_steps - int(self.n_inference_steps * strength)
-    self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps)
+    self.timesteps = np.linspace(
+        self.n_training_steps - 1, 0, self.n_inference_steps
+    )
     self.timesteps = self.timesteps[start_step:]
     self.initial_scale = self.sigmas[start_step]
     self.step_count = start_step

ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py

@@ -27,7 +27,10 @@ def create_bytes_table() -> dict:
   special_count = 0
   for byte in range(256):
     category = unicodedata.category(chr(byte))
-    if category[0] not in ['C', 'Z']:  # ith character is NOT control char or space
+    if category[0] not in [
+        'C',
+        'Z',
+    ]:  # ith character is NOT control char or space
       table[byte] = chr(byte)
     else:  # ith character IS control char or space
       table[byte] = chr(special_count + 256)
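
This is the GPT-2-style byte-to-unicode table: printable bytes map to themselves, while control and whitespace bytes (Unicode categories starting with 'C' or 'Z') are shifted to code points from U+0100 upward so every byte has a visible stand-in. A quick check (the loop is restated from the hunk; the table initialization, `special_count` increment, and return are filled in, since the hunk cuts them off):

import unicodedata

def create_bytes_table() -> dict:
  table = {}
  special_count = 0
  for byte in range(256):
    category = unicodedata.category(chr(byte))
    if category[0] not in ['C', 'Z']:
      table[byte] = chr(byte)  # printable byte maps to itself
    else:
      table[byte] = chr(special_count + 256)
      special_count += 1
  return table

table = create_bytes_table()
print(table[ord('a')])  # 'a' (printable, unchanged)
print(table[32])        # 'Ġ': bytes 0-31 are controls, so space is chr(256 + 32)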

ai_edge_torch/generative/examples/stable_diffusion/util.py

@@ -20,14 +20,20 @@ import torch
 
 
 def get_time_embedding(timestep):
-  freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
+  freqs = torch.pow(
+      10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160
+  )
   x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
   return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
 
 
-def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
+def get_alphas_cumprod(
+    beta_start=0.00085, beta_end=0.0120, n_training_steps=1000
+):
   betas = (
-      np.linspace(beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32)
+      np.linspace(
+          beta_start**0.5, beta_end**0.5, n_training_steps, dtype=np.float32
+      )
       ** 2
   )
   alphas = 1.0 - betas
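
Two properties of the functions above are worth spelling out: `get_time_embedding` builds a 320-dimensional sinusoidal embedding (160 cosines concatenated with 160 sines), and `get_alphas_cumprod` is the standard "scaled linear" DDPM schedule, squaring a linspace between the square roots of `beta_start` and `beta_end`. A quick check (the `np.cumprod` step is the usual completion of the function, which the hunk cuts off):

import numpy as np
import torch

freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
x = torch.tensor([50], dtype=torch.float32)[:, None] * freqs[None]
emb = torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
print(emb.shape)  # torch.Size([1, 320])

betas = np.linspace(0.00085**0.5, 0.0120**0.5, 1000, dtype=np.float32) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)
print(round(float(alphas_cumprod[0]), 5))   # 0.99915 (= 1 - beta_start)
print(round(float(alphas_cumprod[-1]), 4))  # ~0.0047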

ai_edge_torch/generative/examples/t5/convert_to_tflite.py

@@ -16,12 +16,11 @@
 import os
 from pathlib import Path
 
-import numpy as np
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.t5 import t5
 from ai_edge_torch.generative.quantize import quant_recipes
+import numpy as np
+import torch
 
 
 # TODO(haoliang): clean this up untile 2-sig model is validated e2e.
@@ -73,8 +72,12 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
   embedding_layer = torch.nn.Embedding(
       config.vocab_size, config.embedding_dim, padding_idx=0
   )
-  t5_encoder_model = t5.build_t5_encoder_model(config, embedding_layer, checkpoint_path)
-  t5_decoder_model = t5.build_t5_decoder_model(config, embedding_layer, checkpoint_path)
+  t5_encoder_model = t5.build_t5_encoder_model(
+      config, embedding_layer, checkpoint_path
+  )
+  t5_decoder_model = t5.build_t5_decoder_model(
+      config, embedding_layer, checkpoint_path
+  )
 
   # encoder
   seq_len = 512