ai-edge-torch-nightly 0.2.0.dev20240707__py3-none-any.whl → 0.2.0.dev20240713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/convert/conversion.py +2 -4
- ai_edge_torch/convert/conversion_utils.py +61 -3
- ai_edge_torch/convert/converter.py +47 -16
- ai_edge_torch/convert/test/test_convert.py +39 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -10
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +56 -30
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +72 -69
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +80 -72
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +1 -1
- ai_edge_torch/generative/examples/t5/t5_attention.py +6 -1
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +1 -1
- ai_edge_torch/generative/layers/model_config.py +4 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +1 -1
- ai_edge_torch/generative/layers/unet/model_config.py +5 -5
- ai_edge_torch/generative/utilities/loader.py +9 -6
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +74 -10
- ai_edge_torch/model.py +11 -3
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -13
- {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/RECORD +23 -23
- {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240707.dist-info → ai_edge_torch_nightly-0.2.0.dev20240713.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -22,29 +22,31 @@ import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
-TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
-    post_quant_conv="
-    conv_in="
+TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
+    post_quant_conv="first_stage_model.post_quant_conv",
+    conv_in="first_stage_model.decoder.conv_in",
     mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
         residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1="
-                norm_2="
-                conv_1="
-                conv_2="
+                norm_1="first_stage_model.decoder.mid.block_1.norm1",
+                norm_2="first_stage_model.decoder.mid.block_1.norm2",
+                conv_1="first_stage_model.decoder.mid.block_1.conv1",
+                conv_2="first_stage_model.decoder.mid.block_1.conv2",
            ),
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1="
-                norm_2="
-                conv_1="
-                conv_2="
+                norm_1="first_stage_model.decoder.mid.block_2.norm1",
+                norm_2="first_stage_model.decoder.mid.block_2.norm2",
+                conv_1="first_stage_model.decoder.mid.block_2.conv1",
+                conv_2="first_stage_model.decoder.mid.block_2.conv2",
            ),
        ],
        attention_block_tensor_names=[
            stable_diffusion_loader.AttentionBlockTensorNames(
-                norm="
-
-
+                norm="first_stage_model.decoder.mid.attn_1.norm",
+                q_proj="first_stage_model.decoder.mid.attn_1.q",
+                k_proj="first_stage_model.decoder.mid.attn_1.k",
+                v_proj="first_stage_model.decoder.mid.attn_1.v",
+                output_proj="first_stage_model.decoder.mid.attn_1.proj_out",
            )
        ],
    ),
@@ -52,99 +54,99 @@ TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.3.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.0.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.3.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.3.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.2.conv2",
                ),
            ],
-            upsample_conv="
+            upsample_conv="first_stage_model.decoder.up.3.upsample.conv",
        ),
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.2.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.0.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.2.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.2.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.2.conv2",
                ),
            ],
-            upsample_conv="
+            upsample_conv="first_stage_model.decoder.up.2.upsample.conv",
        ),
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
-                    residual_layer="
+                    norm_1="first_stage_model.decoder.up.1.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.0.conv2",
+                    residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.1.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.1.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.2.conv2",
                ),
            ],
-            upsample_conv="
+            upsample_conv="first_stage_model.decoder.up.1.upsample.conv",
        ),
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
-                    residual_layer="
+                    norm_1="first_stage_model.decoder.up.0.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.0.conv2",
+                    residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.0.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="
-                    norm_2="
-                    conv_1="
-                    conv_2="
+                    norm_1="first_stage_model.decoder.up.0.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.2.conv2",
                ),
            ],
        ),
    ],
-    final_norm="
-    conv_out="
+    final_norm="first_stage_model.decoder.norm_out",
+    conv_out="first_stage_model.decoder.conv_out",
 )
 
 
@@ -288,6 +290,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
        output_proj_use_bias=True,
        enable_kv_cache=False,
        qkv_transpose_before_split=True,
+        qkv_fused_interleaved=False,
        rotary_percentage=0.0,
    ),
 )
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -26,12 +26,12 @@ _down_encoder_blocks_tensor_names = [
    stable_diffusion_loader.DownEncoderBlockTensorNames(
        residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1=f"
-                conv_1=f"
-                norm_2=f"
-                conv_2=f"
-                time_embedding=f"
-                residual_layer=f"
+                norm_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.0",
+                conv_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.2",
+                norm_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.0",
+                conv_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.3",
+                time_embedding=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.emb_layers.1",
+                residual_layer=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.skip_connection"
                if (i * 3 + j + 1) in [4, 7]
                else None,
            )
@@ -39,32 +39,36 @@ _down_encoder_blocks_tensor_names = [
        ],
        transformer_block_tensor_names=[
            stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=f"
-                conv_in=f"
-                conv_out=f"
+                pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
+                conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
+                conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
                self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
-                    norm=f"
-
-
+                    norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
+                    q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
+                    k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_k",
+                    v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_v",
+                    output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_out.0",
                ),
                cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
-                    norm=f"
-                    q_proj=f"
-                    k_proj=f"
-                    v_proj=f"
-                    output_proj=f"
+                    norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm2",
+                    q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_q",
+                    k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_k",
+                    v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_v",
+                    output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_out.0",
                ),
                feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
-                    norm=f"
-                    ge_glu=f"
-                    w2=f"
+                    norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm3",
+                    ge_glu=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.0.proj",
+                    w2=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.2",
                ),
            )
            for j in range(2)
        ]
        if i < 3
        else None,
-        downsample_conv=f"
+        downsample_conv=f"model.diffusion_model.input_blocks.{i*3+3}.0.op"
+        if i < 3
+        else None,
    )
    for i in range(4)
 ]
@@ -72,35 +76,37 @@ _down_encoder_blocks_tensor_names = [
 _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
    residual_block_tensor_names=[
        stable_diffusion_loader.ResidualBlockTensorNames(
-            norm_1=f"
-            conv_1=f"
-            norm_2=f"
-            conv_2=f"
-            time_embedding=f"
+            norm_1=f"model.diffusion_model.middle_block.{i}.in_layers.0",
+            conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
+            norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
+            conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
+            time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
        )
        for i in [0, 2]
    ],
    transformer_block_tensor_names=[
        stable_diffusion_loader.TransformerBlockTensorNames(
-            pre_conv_norm=f"
-            conv_in=f"
-            conv_out=f"
+            pre_conv_norm=f"model.diffusion_model.middle_block.{i}.norm",
+            conv_in=f"model.diffusion_model.middle_block.{i}.proj_in",
+            conv_out=f"model.diffusion_model.middle_block.{i}.proj_out",
            self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
-                norm=f"
-
-
+                norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm1",
+                q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_q",
+                k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_k",
+                v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_v",
+                output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_out.0",
            ),
            cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
-                norm=f"
-                q_proj=f"
-                k_proj=f"
-                v_proj=f"
-                output_proj=f"
+                norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm2",
+                q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_q",
+                k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_k",
+                v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_v",
+                output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_out.0",
            ),
            feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
-                norm=f"
-                ge_glu=f"
-                w2=f"
+                norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm3",
+                ge_glu=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.0.proj",
+                w2=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.2",
            ),
        )
        for i in [1]
@@ -111,58 +117,59 @@ _up_decoder_blocks_tensor_names = [
    stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
        residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1=f"
-                conv_1=f"
-                norm_2=f"
-                conv_2=f"
-                time_embedding=f"
-                residual_layer=f"
+                norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
+                conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
+                norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
+                conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
+                time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
+                residual_layer=f"model.diffusion_model.output_blocks.{i*3+j}.0.skip_connection",
            )
            for j in range(3)
        ],
        transformer_block_tensor_names=[
            stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=f"
-                conv_in=f"
-                conv_out=f"
+                pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
+                conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
+                conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
                self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
-                    norm=f"
-
-
+                    norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
+                    q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
+                    k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_k",
+                    v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_v",
+                    output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_out.0",
                ),
                cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
-                    norm=f"
-                    q_proj=f"
-                    k_proj=f"
-                    v_proj=f"
-                    output_proj=f"
+                    norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm2",
+                    q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_q",
+                    k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_k",
+                    v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_v",
+                    output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_out.0",
                ),
                feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
-                    norm=f"
-                    ge_glu=f"
-                    w2=f"
+                    norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm3",
+                    ge_glu=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.0.proj",
+                    w2=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.2",
                ),
            )
            for j in range(3)
        ]
        if i > 0
        else None,
-        upsample_conv=f"
+        upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
        if 0 < i < 3
-        else (f"
+        else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
    )
    for i in range(4)
 ]
 
-
-TENSORS_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
+TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
    time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames(
-        w1="
-        w2="
+        w1="model.diffusion_model.time_embed.0",
+        w2="model.diffusion_model.time_embed.2",
    ),
-    conv_in="
-    conv_out="
-    final_norm="
+    conv_in="model.diffusion_model.input_blocks.0.0",
+    conv_out="model.diffusion_model.out.2",
+    final_norm="model.diffusion_model.out.0",
    down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names,
    mid_block_tensor_names=_mid_block_tensor_names,
    up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names,
@@ -249,6 +256,7 @@ class Diffusion(nn.Module):
        qkv_use_bias=False,
        output_proj_use_bias=True,
        enable_kv_cache=False,
+        qkv_fused_interleaved=False,
    )
 
    # Down encoders.
@@ -280,7 +288,7 @@ class Diffusion(nn.Module):
                stride=2,
                padding=config.downsample_padding,
            ),
-            transformer_block_config=unet_cfg.
+            transformer_block_config=unet_cfg.TransformerBlock2DConfig(
                attention_block_config=unet_cfg.AttentionBlock2DConfig(
                    dim=output_channel,
                    attention_batch_size=config.transformer_batch_size,
@@ -340,7 +348,7 @@ class Diffusion(nn.Module):
            ),
            num_layers=config.mid_block_layers,
            time_embedding_channels=config.time_embedding_blocks_dim,
-            transformer_block_config=unet_cfg.
+            transformer_block_config=unet_cfg.TransformerBlock2DConfig(
                attention_block_config=unet_cfg.AttentionBlock2DConfig(
                    dim=mid_block_channels,
                    attention_batch_size=config.transformer_batch_size,
@@ -401,7 +409,7 @@ class Diffusion(nn.Module):
                mode=unet_cfg.SamplingType.NEAREST,
                scale_factor=2,
            ),
-            transformer_block_config=unet_cfg.
+            transformer_block_config=unet_cfg.TransformerBlock2DConfig(
                attention_block_config=unet_cfg.AttentionBlock2DConfig(
                    dim=output_channel,
                    attention_batch_size=config.transformer_batch_size,
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py
@@ -167,7 +167,7 @@ def run_tflite_pipeline(
  if input_image:
    if not hasattr(model, 'encoder'):
      raise AttributeError(
-          'Stable Diffusion must be
+          'Stable Diffusion must be initialized with encoder to accept input_image.'
      )
    input_image = input_image.resize((width, height))
    input_image_np = np.array(input_image).astype(np.float32)
ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -27,6 +27,8 @@ import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention  # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb  # NOQA
 
+BATCH_SIZE = 1
+
 
 class EncoderDecoderBlock(nn.Module):
 
@@ -44,6 +46,7 @@ class EncoderDecoderBlock(nn.Module):
 
    super().__init__()
    self.atten_func = T5Attention(
+        BATCH_SIZE,
        config.embedding_dim,
        config.attn_config,
        config.pre_attention_norm_config,
@@ -54,6 +57,7 @@ class EncoderDecoderBlock(nn.Module):
    # For a decoder, we add a cross attention.
    if config.is_decoder:
      self.cross_atten_func = T5Attention(
+          BATCH_SIZE,
          config.embedding_dim,
          config.attn_config,
          config.pre_attention_norm_config,
@@ -127,6 +131,7 @@ class T5Attention(CrossAttention):
 
  def __init__(
      self,
+      batch: int,
      dim: int,
      config: cfg.AttentionConfig,
      norm_config: cfg.NormalizationConfig,
@@ -144,7 +149,7 @@ class T5Attention(CrossAttention):
      enable_hlfb (bool): whether hlfb is enabled or not.
      has_relative_attention_bias (bool): whether we compute relative bias.
    """
-    super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
+    super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)
    self.pre_atten_norm = builder.build_norm(dim, norm_config)
 
    self.has_relative_attention_bias = has_relative_attention_bias
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
@@ -40,7 +40,7 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
        if self.is_zero_tensor_node(source):
          # Remove the mark_tensor call on the mask input by
          # replacing the target with an identity function.
-          node.target = lambda *args, **kwargs: args[0]
+          node.target = lambda *args, **kwargs: torch.zeros_like(args[0])
 
    exported_program.graph_module.graph.lint()
    exported_program.graph_module.recompile()
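The rewritten pass target no longer forwards the marked mask tensor; it materializes an all-zero tensor of the same shape and dtype instead. A minimal, self-contained sketch of the same node-target rewrite on a toy torch.fx graph (the module and helper below are illustrative stand-ins, not code from this package):

import torch
import torch.fx as fx


def _zeros_like_first_arg(*args, **kwargs):
  # Same effect as the pass's new lambda target: yield zeros shaped like
  # the first argument instead of forwarding the argument itself.
  return torch.zeros_like(args[0])


class _MaskedOp(torch.nn.Module):
  # Stand-in for a graph in which a known-zero mask flows through an op.
  def forward(self, mask):
    return torch.add(mask, 1)


gm = fx.symbolic_trace(_MaskedOp())
for node in gm.graph.nodes:
  if node.op == "call_function":
    node.target = _zeros_like_first_arg  # the rewrite the pass performs
gm.graph.lint()
gm.recompile()

assert torch.equal(gm(torch.ones(2, 2)), torch.zeros(2, 2))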
ai_edge_torch/generative/layers/model_config.py
@@ -68,6 +68,10 @@ class AttentionConfig:
  qkv_transpose_before_split: bool = False
  # Whether to use bias with Query, Key, and Value projection.
  qkv_use_bias: bool = False
+  # Whether the fused q, k, v projection weights interleaves q, k, v heads.
+  # If True, the projection weights are in format [q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]
+  # If False, the projection weights are in format [q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]
+  qkv_fused_interleaved: bool = True
  # Whether to use bias with attention output projection.
  output_proj_use_bias: bool = False
  enable_kv_cache: bool = True
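The layout that qkv_fused_interleaved describes is easiest to see on a toy fused weight. A small sketch with made-up sizes (2 heads, head_dim of 2, as many query groups as heads), not code from the package:

import torch

head_dim = 2
q = torch.tensor([0., 0., 1., 1.])  # [q_head_0 | q_head_1]
k = torch.tensor([2., 2., 3., 3.])  # [k_head_0 | k_head_1]
v = torch.tensor([4., 4., 5., 5.])  # [v_head_0 | v_head_1]

# qkv_fused_interleaved=False: plain concatenation, all q heads first.
fused_concat = torch.cat([q, k, v])
# -> [q_head_0, q_head_1, k_head_0, k_head_1, v_head_0, v_head_1]

# qkv_fused_interleaved=True: q, k, v cycled head by head.
qs, ks, vs = (torch.split(t, head_dim) for t in (q, k, v))
fused_interleaved = torch.cat([t for grp in zip(qs, ks, vs) for t in grp])
# -> [q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1]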
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -272,7 +272,7 @@ class TransformerBlock2D(nn.Module):
 
  """
 
-  def __init__(self, config: unet_cfg.
+  def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
    """Initialize an instance of the TransformerBlock2D.
 
    Args:
ai_edge_torch/generative/layers/unet/model_config.py
@@ -85,7 +85,7 @@ class FeedForwardBlock2DConfig:
 
 
 @dataclass
-class 
+class TransformerBlock2DConfig:
  pre_conv_normalization_config: layers_cfg.NormalizationConfig
  attention_block_config: AttentionBlock2DConfig
  cross_attention_block_config: CrossAttentionBlock2DConfig
@@ -108,7 +108,7 @@ class UpDecoderBlock2DConfig:
  # Optional sampling config if add_upsample is True.
  sampling_config: Optional[UpSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
-  transformer_block_config: Optional[
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None
 
@@ -131,7 +131,7 @@ class SkipUpDecoderBlock2DConfig:
  # Optional sampling config if add_upsample is True.
  sampling_config: Optional[UpSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
-  transformer_block_config: Optional[
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None
 
@@ -152,7 +152,7 @@ class DownEncoderBlock2DConfig:
  # Optional sampling config if add_upsample is True.
  sampling_config: Optional[DownSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
-  transformer_block_config: Optional[
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None
 
@@ -168,7 +168,7 @@ class MidBlock2DConfig:
  # Optional config of attention blocks interleaved with residual blocks
  attention_block_config: Optional[AttentionBlock2DConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
-  transformer_block_config: Optional[
+  transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None
 
ai_edge_torch/generative/utilities/loader.py
@@ -317,9 +317,12 @@ class ModelLoader:
      k: torch.Tensor,
      v: torch.Tensor,
  ) -> torch.Tensor:
-
-
-
-
-
-
+    if config.attn_config.qkv_fused_interleaved:
+      q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+      qs = torch.split(q, config.head_dim * q_per_kv)
+      ks = torch.split(k, config.head_dim)
+      vs = torch.split(v, config.head_dim)
+      cycled = [t for group in zip(qs, ks, vs) for t in group]
+      return torch.cat(cycled)
+    else:
+      return torch.cat([q, k, v], dim=0)
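With grouped-query attention the same cycling happens per KV group rather than per head, which is what the head_dim * q_per_kv split width above accounts for. A toy run with assumed sizes (4 query heads, 2 query groups, head_dim of 1), again illustrative rather than package code:

import torch

num_heads, num_query_groups, head_dim = 4, 2, 1
q_per_kv = num_heads // num_query_groups  # 2 query heads share each k/v head

q = torch.tensor([0., 1., 2., 3.])  # q_head_0 .. q_head_3
k = torch.tensor([10., 11.])        # k_group_0, k_group_1
v = torch.tensor([20., 21.])        # v_group_0, v_group_1

qs = torch.split(q, head_dim * q_per_kv)  # ([0, 1], [2, 3])
ks = torch.split(k, head_dim)             # ([10], [11])
vs = torch.split(v, head_dim)             # ([20], [21])
fused = torch.cat([t for group in zip(qs, ks, vs) for t in group])
# -> [0, 1, 10, 20, 2, 3, 11, 21]: each group's q heads, then its k and v.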