ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic.
- ai_edge_torch/__init__.py +31 -0
- ai_edge_torch/convert/__init__.py +14 -0
- ai_edge_torch/convert/conversion.py +117 -0
- ai_edge_torch/convert/conversion_utils.py +400 -0
- ai_edge_torch/convert/converter.py +202 -0
- ai_edge_torch/convert/fx_passes/__init__.py +59 -0
- ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
- ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
- ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
- ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
- ai_edge_torch/convert/test/__init__.py +14 -0
- ai_edge_torch/convert/test/test_convert.py +311 -0
- ai_edge_torch/convert/test/test_convert_composites.py +192 -0
- ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
- ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
- ai_edge_torch/convert/to_channel_last_io.py +85 -0
- ai_edge_torch/debug/__init__.py +17 -0
- ai_edge_torch/debug/culprit.py +464 -0
- ai_edge_torch/debug/test/__init__.py +14 -0
- ai_edge_torch/debug/test/test_culprit.py +133 -0
- ai_edge_torch/debug/test/test_search_model.py +50 -0
- ai_edge_torch/debug/utils.py +48 -0
- ai_edge_torch/experimental/__init__.py +14 -0
- ai_edge_torch/generative/__init__.py +14 -0
- ai_edge_torch/generative/examples/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
- ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
- ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
- ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
- ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
- ai_edge_torch/generative/examples/t5/__init__.py +14 -0
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
- ai_edge_torch/generative/examples/t5/t5.py +608 -0
- ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
- ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
- ai_edge_torch/generative/fx_passes/__init__.py +31 -0
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
- ai_edge_torch/generative/layers/__init__.py +14 -0
- ai_edge_torch/generative/layers/attention.py +354 -0
- ai_edge_torch/generative/layers/attention_utils.py +169 -0
- ai_edge_torch/generative/layers/builder.py +131 -0
- ai_edge_torch/generative/layers/feed_forward.py +95 -0
- ai_edge_torch/generative/layers/kv_cache.py +83 -0
- ai_edge_torch/generative/layers/model_config.py +158 -0
- ai_edge_torch/generative/layers/normalization.py +62 -0
- ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/layers/unet/__init__.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
- ai_edge_torch/generative/layers/unet/builder.py +47 -0
- ai_edge_torch/generative/layers/unet/model_config.py +269 -0
- ai_edge_torch/generative/quantize/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
- ai_edge_torch/generative/quantize/example.py +45 -0
- ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
- ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
- ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
- ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
- ai_edge_torch/generative/test/__init__.py +14 -0
- ai_edge_torch/generative/test/loader_test.py +80 -0
- ai_edge_torch/generative/test/test_model_conversion.py +235 -0
- ai_edge_torch/generative/test/test_quantize.py +162 -0
- ai_edge_torch/generative/utilities/__init__.py +15 -0
- ai_edge_torch/generative/utilities/loader.py +328 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
- ai_edge_torch/generative/utilities/t5_loader.py +483 -0
- ai_edge_torch/hlfb/__init__.py +16 -0
- ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
- ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
- ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
- ai_edge_torch/hlfb/test/__init__.py +14 -0
- ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
- ai_edge_torch/model.py +142 -0
- ai_edge_torch/quantize/__init__.py +16 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
- ai_edge_torch/quantize/quant_config.py +81 -0
- ai_edge_torch/testing/__init__.py +14 -0
- ai_edge_torch/testing/model_coverage/__init__.py +16 -0
- ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
--- /dev/null
+++ b/ai_edge_torch/generative/examples/stable_diffusion/clip.py
@@ -0,0 +1,115 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+from torch import nn
+
+from ai_edge_torch.generative.layers.attention import TransformerBlock
+import ai_edge_torch.generative.layers.attention_utils as attention_utils
+import ai_edge_torch.generative.layers.builder as builder
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
+    ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+    attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
+    attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
+    attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
+    attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
+    pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
+    pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
+    embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+    embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+    final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
+    lm_head=None,
+)
+
+
+class CLIP(nn.Module):
+  """CLIP text encoder
+  For details, see https://arxiv.org/abs/2103.00020
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.tok_embedding_position = nn.Parameter(
+        torch.zeros((config.max_seq_len, config.embedding_dim))
+    )
+
+    self.config = config
+    self.transformer_blocks = nn.ModuleList(
+        TransformerBlock(config) for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)
+
+    self.mask_cache = attention_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32
+    )
+
+  @torch.inference_mode
+  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
+    tokens = tokens.type(torch.long)
+
+    state = self.tok_embedding(tokens) + self.tok_embedding_position
+    for layer in self.transformer_blocks:
+      state = layer(state, mask=self.mask_cache)
+    output = self.final_norm(state)
+    return output
+
+
+def get_model_config() -> cfg.ModelConfig:
+  max_seq_len = 77
+  vocab_size = 49408
+  num_layers = 12
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 768
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      pre_ff_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
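As a quick orientation to the file above, here is a minimal sketch of exercising the encoder standalone. This is illustrative only and not part of the package: the model is randomly initialized until a checkpoint is loaded (e.g. via stable_diffusion_loader.ClipModelLoader, as the conversion script below does). Shapes follow get_model_config(): a fixed 77-token context and 768-dim embeddings.

import torch
from ai_edge_torch.generative.examples.stable_diffusion import clip

# Build the CLIP text encoder with the Stable Diffusion v1.5 config.
model = clip.CLIP(clip.get_model_config())

# Batch of one prompt, padded to the fixed 77-token context length.
tokens = torch.zeros((1, 77), dtype=torch.long)
embeddings = model(tokens)  # (1, 77, 768) text-conditioning tensor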
--- /dev/null
+++ b/ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import os
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+import ai_edge_torch
+import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
+import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
+import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
+from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
+import ai_edge_torch.generative.examples.stable_diffusion.util as util
+import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+
+arg_parser = argparse.ArgumentParser()
+arg_parser.add_argument(
+    '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+)
+arg_parser.add_argument(
+    '--diffusion_ckpt',
+    type=str,
+    help='Path to source diffusion model checkpoint',
+    required=True,
+)
+arg_parser.add_argument(
+    '--decoder_ckpt',
+    type=str,
+    help='Path to source image decoder model checkpoint',
+    required=True,
+)
+arg_parser.add_argument(
+    '--output_dir',
+    type=str,
+    help='Path to the converted TF Lite directory.',
+    required=True,
+)
+
+
+@torch.inference_mode
+def convert_stable_diffusion_to_tflite(
+    output_dir: str,
+    clip_ckpt_path: str,
+    diffusion_ckpt_path: str,
+    decoder_ckpt_path: str,
+    image_height: int = 512,
+    image_width: int = 512,
+):
+
+  clip_model = clip.CLIP(clip.get_model_config())
+  loader = stable_diffusion_loader.ClipModelLoader(
+      clip_ckpt_path,
+      clip.TENSOR_NAMES,
+  )
+  loader.load(clip_model, strict=False)
+
+  diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
+  diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
+      diffusion_ckpt_path, diffusion.TENSOR_NAMES
+  )
+  diffusion_loader.load(diffusion_model, strict=False)
+
+  decoder_model = decoder.Decoder(decoder.get_model_config())
+  decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
+      decoder_ckpt_path, decoder.TENSOR_NAMES
+  )
+  decoder_loader.load(decoder_model, strict=False)
+
+  # TODO(yichunk): enable image encoder conversion
+  # if encoder_ckpt_path is not None:
+  #   encoder = Encoder()
+  #   encoder.load_state_dict(torch.load(encoder_ckpt_path))
+
+  # Tensors used to trace the model graph during conversion.
+  n_tokens = 77
+  timestamp = 0
+  len_prompt = 1
+  prompt_tokens = torch.full((1, n_tokens), 0, dtype=torch.long)
+  input_image = torch.full((1, 3, image_height, image_width), 0, dtype=torch.float32)
+  noise = torch.full(
+      (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
+  )
+
+  input_latents = torch.zeros_like(noise)
+  context_cond = clip_model(prompt_tokens)
+  context_uncond = torch.zeros_like(context_cond)
+  context = torch.cat([context_cond, context_uncond], axis=0)
+  time_embedding = util.get_time_embedding(timestamp)
+
+  if not os.path.exists(output_dir):
+    Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+  # TODO(yichunk): convert to multi signature tflite model.
+  # CLIP text encoder
+  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
+      f'{output_dir}/clip.tflite'
+  )
+
+  # TODO(yichunk): enable image encoder conversion
+  # Image encoder
+  # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+  #     f'{output_dir}/encoder.tflite'
+  # )
+
+  # Diffusion
+  ai_edge_torch.signature(
+      'diffusion',
+      diffusion_model,
+      (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
+  ).convert().export(f'{output_dir}/diffusion.tflite')
+
+  # Image decoder
+  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
+      f'{output_dir}/decoder.tflite'
+  )
+
+
+if __name__ == '__main__':
+  args = arg_parser.parse_args()
+  convert_stable_diffusion_to_tflite(
+      output_dir=args.output_dir,
+      clip_ckpt_path=args.clip_ckpt,
+      diffusion_ckpt_path=args.diffusion_ckpt,
+      decoder_ckpt_path=args.decoder_ckpt,
+      image_height=512,
+      image_width=512,
+  )
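The script above is argparse-driven, but its entry point can also be called directly. A hedged sketch of programmatic use (the checkpoint paths are placeholders, not files shipped in the wheel):

from ai_edge_torch.generative.examples.stable_diffusion import convert_to_tflite

# Writes clip.tflite, diffusion.tflite and decoder.tflite into output_dir.
convert_to_tflite.convert_stable_diffusion_to_tflite(
    output_dir='/tmp/sd_tflite',                  # placeholder output directory
    clip_ckpt_path='/ckpts/clip.ckpt',            # placeholder checkpoint paths
    diffusion_ckpt_path='/ckpts/diffusion.ckpt',  # placeholder
    decoder_ckpt_path='/ckpts/decoder.ckpt',      # placeholder
    image_height=512,
    image_width=512,
)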
--- /dev/null
+++ b/ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -0,0 +1,317 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import torch
+from torch import nn
+
+import ai_edge_torch.generative.layers.builder as layers_builder
+import ai_edge_torch.generative.layers.model_config as layers_cfg
+import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
+import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
+import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+
+TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
+    post_quant_conv="first_stage_model.post_quant_conv",
+    conv_in="first_stage_model.decoder.conv_in",
+    mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
+        residual_block_tensor_names=[
+            stable_diffusion_loader.ResidualBlockTensorNames(
+                norm_1="first_stage_model.decoder.mid.block_1.norm1",
+                norm_2="first_stage_model.decoder.mid.block_1.norm2",
+                conv_1="first_stage_model.decoder.mid.block_1.conv1",
+                conv_2="first_stage_model.decoder.mid.block_1.conv2",
+            ),
+            stable_diffusion_loader.ResidualBlockTensorNames(
+                norm_1="first_stage_model.decoder.mid.block_2.norm1",
+                norm_2="first_stage_model.decoder.mid.block_2.norm2",
+                conv_1="first_stage_model.decoder.mid.block_2.conv1",
+                conv_2="first_stage_model.decoder.mid.block_2.conv2",
+            ),
+        ],
+        attention_block_tensor_names=[
+            stable_diffusion_loader.AttentionBlockTensorNames(
+                norm="first_stage_model.decoder.mid.attn_1.norm",
+                q_proj="first_stage_model.decoder.mid.attn_1.q",
+                k_proj="first_stage_model.decoder.mid.attn_1.k",
+                v_proj="first_stage_model.decoder.mid.attn_1.v",
+                output_proj="first_stage_model.decoder.mid.attn_1.proj_out",
+            )
+        ],
+    ),
+    up_decoder_blocks_tensor_names=[
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.3.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.0.conv2",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.3.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.1.conv2",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.3.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.2.conv2",
+                ),
+            ],
+            upsample_conv="first_stage_model.decoder.up.3.upsample.conv",
+        ),
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.2.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.0.conv2",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.2.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.1.conv2",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.2.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.2.conv2",
+                ),
+            ],
+            upsample_conv="first_stage_model.decoder.up.2.upsample.conv",
+        ),
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.1.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.0.conv2",
+                    residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.1.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.1.conv2",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.1.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.2.conv2",
+                ),
+            ],
+            upsample_conv="first_stage_model.decoder.up.1.upsample.conv",
+        ),
+        stable_diffusion_loader.UpDecoderBlockTensorNames(
+            residual_block_tensor_names=[
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.0.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.0.conv2",
+                    residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.0.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.1.conv2",
+                ),
+                stable_diffusion_loader.ResidualBlockTensorNames(
+                    norm_1="first_stage_model.decoder.up.0.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.2.conv2",
+                ),
+            ],
+        ),
+    ],
+    final_norm="first_stage_model.decoder.norm_out",
+    conv_out="first_stage_model.decoder.conv_out",
+)
+
+
+class Decoder(nn.Module):
+  """The Decoder model used in Stable Diffusion.
+
+  For details, see https://arxiv.org/abs/2103.00020
+
+  Structure of the Decoder:
+
+        latents tensor
+              │
+              ▼
+  ┌───────────────────┐
+  │  Post Quant Conv  │
+  └─────────┬─────────┘
+            │
+  ┌─────────▼─────────┐
+  │      ConvIn       │
+  └─────────┬─────────┘
+            │
+  ┌─────────▼─────────┐
+  │    MidBlock2D     │
+  └─────────┬─────────┘
+            │
+  ┌─────────▼─────────┐
+  │    UpDecoder2D    │ x 4
+  └─────────┬─────────┘
+            │
+  ┌─────────▼─────────┐
+  │     FinalNorm     │
+  └─────────┬─────────┘
+            │
+  ┌─────────▼─────────┐
+  │    Activation     │
+  └─────────┬─────────┘
+            │
+  ┌─────────▼─────────┐
+  │      ConvOut      │
+  └─────────┬─────────┘
+            │
+            ▼
+        Output Image
+  """
+
+  def __init__(self, config: unet_cfg.AutoEncoderConfig):
+    super().__init__()
+    self.config = config
+    self.post_quant_conv = nn.Conv2d(
+        config.latent_channels,
+        config.latent_channels,
+        kernel_size=1,
+        stride=1,
+        padding=0,
+    )
+    reversed_block_out_channels = list(reversed(config.block_out_channels))
+    self.conv_in = nn.Conv2d(
+        config.latent_channels,
+        reversed_block_out_channels[0],
+        kernel_size=3,
+        stride=1,
+        padding=1,
+    )
+    self.mid_block = blocks_2d.MidBlock2D(config.mid_block_config)
+    up_decoder_blocks = []
+    block_out_channels = reversed_block_out_channels[0]
+    for i, out_channels in enumerate(reversed_block_out_channels):
+      prev_output_channel = block_out_channels
+      block_out_channels = out_channels
+      not_final_block = i < len(reversed_block_out_channels) - 1
+      up_decoder_blocks.append(
+          blocks_2d.UpDecoderBlock2D(
+              unet_cfg.UpDecoderBlock2DConfig(
+                  in_channels=prev_output_channel,
+                  out_channels=block_out_channels,
+                  normalization_config=config.normalization_config,
+                  activation_config=config.activation_config,
+                  num_layers=config.layers_per_block,
+                  add_upsample=not_final_block,
+                  upsample_conv=True,
+                  sampling_config=unet_cfg.UpSamplingConfig(
+                      mode=unet_cfg.SamplingType.NEAREST, scale_factor=2
+                  ),
+              )
+          )
+      )
+    self.up_decoder_blocks = nn.ModuleList(up_decoder_blocks)
+    self.final_norm = layers_builder.build_norm(
+        block_out_channels, config.normalization_config
+    )
+    self.act_fn = layers_builder.get_activation(config.activation_config)
+    self.conv_out = nn.Conv2d(
+        block_out_channels,
+        config.out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+    )
+
+  def forward(self, latents_tensor: torch.Tensor) -> torch.Tensor:
+    """Forward function of decoder model.
+
+    Args:
+      latents_tensor (torch.Tensor): latent space tensor.
+
+    Returns:
+      output decoded image tensor from decoder model.
+    """
+    x = latents_tensor / self.config.scaling_factor
+    x = self.post_quant_conv(x)
+    x = self.conv_in(x)
+    x = self.mid_block(x)
+    for up_decoder_block in self.up_decoder_blocks:
+      x = up_decoder_block(x)
+    x = self.final_norm(x)
+    x = self.act_fn(x)
+    x = self.conv_out(x)
+    return x
+
+
+def get_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get configs for the Decoder of Stable Diffusion v1.5."""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [128, 256, 512, 512]
+  scaling_factor = 0.18215
+  layers_per_block = 3
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=32
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dim=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          qkv_fused_interleaved=False,
+          rotary_percentage=0.0,
+      ),
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
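To make the tensor shapes concrete: with this config, add_upsample is True for the first three of the four up-decoder blocks, so the latent grid is upsampled 2x three times (8x total), which is why convert_to_tflite.py builds its tracing latents at image_height // 8. A minimal sketch with a randomly initialized decoder (illustrative only; real outputs require weights loaded via AutoEncoderModelLoader):

import torch
from ai_edge_torch.generative.examples.stable_diffusion import decoder

model = decoder.Decoder(decoder.get_model_config())

# A 4-channel 64x64 latent grid decodes to a 3-channel 512x512 image.
latents = torch.randn(1, 4, 64, 64)
image = model(latents)
print(image.shape)  # torch.Size([1, 3, 512, 512])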