ai-edge-torch-nightly 0.2.0.dev20240710__py3-none-any.whl → 0.2.0.dev20240712__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (23)
  1. ai_edge_torch/convert/conversion.py +2 -4
  2. ai_edge_torch/convert/conversion_utils.py +61 -3
  3. ai_edge_torch/convert/converter.py +47 -16
  4. ai_edge_torch/convert/test/test_convert.py +39 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -10
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +56 -30
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +72 -69
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +80 -72
  9. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +1 -1
  10. ai_edge_torch/generative/examples/t5/t5_attention.py +6 -1
  11. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +1 -1
  12. ai_edge_torch/generative/layers/model_config.py +4 -0
  13. ai_edge_torch/generative/layers/unet/blocks_2d.py +1 -1
  14. ai_edge_torch/generative/layers/unet/model_config.py +5 -5
  15. ai_edge_torch/generative/utilities/loader.py +9 -6
  16. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +74 -10
  17. ai_edge_torch/model.py +11 -3
  18. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -13
  19. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/RECORD +23 -23
  21. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/stable_diffusion/decoder.py
@@ -22,29 +22,31 @@ import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
 import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
 import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
 
-TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
-    post_quant_conv="0",
-    conv_in="1",
+TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
+    post_quant_conv="first_stage_model.post_quant_conv",
+    conv_in="first_stage_model.decoder.conv_in",
     mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
         residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1="2.groupnorm_1",
-                norm_2="2.groupnorm_2",
-                conv_1="2.conv_1",
-                conv_2="2.conv_2",
+                norm_1="first_stage_model.decoder.mid.block_1.norm1",
+                norm_2="first_stage_model.decoder.mid.block_1.norm2",
+                conv_1="first_stage_model.decoder.mid.block_1.conv1",
+                conv_2="first_stage_model.decoder.mid.block_1.conv2",
            ),
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1="4.groupnorm_1",
-                norm_2="4.groupnorm_2",
-                conv_1="4.conv_1",
-                conv_2="4.conv_2",
+                norm_1="first_stage_model.decoder.mid.block_2.norm1",
+                norm_2="first_stage_model.decoder.mid.block_2.norm2",
+                conv_1="first_stage_model.decoder.mid.block_2.conv1",
+                conv_2="first_stage_model.decoder.mid.block_2.conv2",
            ),
        ],
        attention_block_tensor_names=[
            stable_diffusion_loader.AttentionBlockTensorNames(
-                norm="3.groupnorm",
-                fused_qkv_proj="3.attention.in_proj",
-                output_proj="3.attention.out_proj",
+                norm="first_stage_model.decoder.mid.attn_1.norm",
+                q_proj="first_stage_model.decoder.mid.attn_1.q",
+                k_proj="first_stage_model.decoder.mid.attn_1.k",
+                v_proj="first_stage_model.decoder.mid.attn_1.v",
+                output_proj="first_stage_model.decoder.mid.attn_1.proj_out",
            )
        ],
    ),
@@ -52,99 +54,99 @@ TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="5.groupnorm_1",
-                    norm_2="5.groupnorm_2",
-                    conv_1="5.conv_1",
-                    conv_2="5.conv_2",
+                    norm_1="first_stage_model.decoder.up.3.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.0.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="6.groupnorm_1",
-                    norm_2="6.groupnorm_2",
-                    conv_1="6.conv_1",
-                    conv_2="6.conv_2",
+                    norm_1="first_stage_model.decoder.up.3.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="7.groupnorm_1",
-                    norm_2="7.groupnorm_2",
-                    conv_1="7.conv_1",
-                    conv_2="7.conv_2",
+                    norm_1="first_stage_model.decoder.up.3.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.3.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.3.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.3.block.2.conv2",
                ),
            ],
-            upsample_conv="9",
+            upsample_conv="first_stage_model.decoder.up.3.upsample.conv",
        ),
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="10.groupnorm_1",
-                    norm_2="10.groupnorm_2",
-                    conv_1="10.conv_1",
-                    conv_2="10.conv_2",
+                    norm_1="first_stage_model.decoder.up.2.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.0.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="11.groupnorm_1",
-                    norm_2="11.groupnorm_2",
-                    conv_1="11.conv_1",
-                    conv_2="11.conv_2",
+                    norm_1="first_stage_model.decoder.up.2.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="12.groupnorm_1",
-                    norm_2="12.groupnorm_2",
-                    conv_1="12.conv_1",
-                    conv_2="12.conv_2",
+                    norm_1="first_stage_model.decoder.up.2.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.2.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.2.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.2.block.2.conv2",
                ),
            ],
-            upsample_conv="14",
+            upsample_conv="first_stage_model.decoder.up.2.upsample.conv",
        ),
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="15.groupnorm_1",
-                    norm_2="15.groupnorm_2",
-                    conv_1="15.conv_1",
-                    conv_2="15.conv_2",
-                    residual_layer="15.residual_layer",
+                    norm_1="first_stage_model.decoder.up.1.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.0.conv2",
+                    residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="16.groupnorm_1",
-                    norm_2="16.groupnorm_2",
-                    conv_1="16.conv_1",
-                    conv_2="16.conv_2",
+                    norm_1="first_stage_model.decoder.up.1.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="17.groupnorm_1",
-                    norm_2="17.groupnorm_2",
-                    conv_1="17.conv_1",
-                    conv_2="17.conv_2",
+                    norm_1="first_stage_model.decoder.up.1.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.1.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.1.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.1.block.2.conv2",
                ),
            ],
-            upsample_conv="19",
+            upsample_conv="first_stage_model.decoder.up.1.upsample.conv",
        ),
        stable_diffusion_loader.UpDecoderBlockTensorNames(
            residual_block_tensor_names=[
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="20.groupnorm_1",
-                    norm_2="20.groupnorm_2",
-                    conv_1="20.conv_1",
-                    conv_2="20.conv_2",
-                    residual_layer="20.residual_layer",
+                    norm_1="first_stage_model.decoder.up.0.block.0.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.0.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.0.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.0.conv2",
+                    residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="21.groupnorm_1",
-                    norm_2="21.groupnorm_2",
-                    conv_1="21.conv_1",
-                    conv_2="21.conv_2",
+                    norm_1="first_stage_model.decoder.up.0.block.1.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.1.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.1.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.1.conv2",
                ),
                stable_diffusion_loader.ResidualBlockTensorNames(
-                    norm_1="22.groupnorm_1",
-                    norm_2="22.groupnorm_2",
-                    conv_1="22.conv_1",
-                    conv_2="22.conv_2",
+                    norm_1="first_stage_model.decoder.up.0.block.2.norm1",
+                    norm_2="first_stage_model.decoder.up.0.block.2.norm2",
+                    conv_1="first_stage_model.decoder.up.0.block.2.conv1",
+                    conv_2="first_stage_model.decoder.up.0.block.2.conv2",
                ),
            ],
        ),
    ],
-    final_norm="23",
-    conv_out="25",
+    final_norm="first_stage_model.decoder.norm_out",
+    conv_out="first_stage_model.decoder.conv_out",
 )
 
 
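Note: the renamed tensors follow the key layout of original CompVis-style Stable Diffusion checkpoints (first_stage_model.* for the VAE here, model.diffusion_model.* for the UNet below), so weights can be looked up directly in a stock checkpoint. A minimal sketch of verifying this, assuming a local SD v1.5 checkpoint; the file name is illustrative, not part of this release:

import torch

ckpt = torch.load("v1-5-pruned-emaonly.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]  # CompVis checkpoints nest weights under "state_dict"
# Keys referenced by the renamed TensorNames above resolve directly:
assert "first_stage_model.post_quant_conv.weight" in state_dict
assert "first_stage_model.decoder.mid.block_1.norm1.weight" in state_dict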
@@ -288,6 +290,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
             output_proj_use_bias=True,
             enable_kv_cache=False,
             qkv_transpose_before_split=True,
+            qkv_fused_interleaved=False,
             rotary_percentage=0.0,
         ),
     )
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py
@@ -26,12 +26,12 @@ _down_encoder_blocks_tensor_names = [
    stable_diffusion_loader.DownEncoderBlockTensorNames(
        residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1=f"unet.encoders.{i*3+j+1}.0.groupnorm_feature",
-                conv_1=f"unet.encoders.{i*3+j+1}.0.conv_feature",
-                norm_2=f"unet.encoders.{i*3+j+1}.0.groupnorm_merged",
-                conv_2=f"unet.encoders.{i*3+j+1}.0.conv_merged",
-                time_embedding=f"unet.encoders.{i*3+j+1}.0.linear_time",
-                residual_layer=f"unet.encoders.{i*3+j+1}.0.residual_layer"
+                norm_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.0",
+                conv_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.2",
+                norm_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.0",
+                conv_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.3",
+                time_embedding=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.emb_layers.1",
+                residual_layer=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.skip_connection"
                if (i * 3 + j + 1) in [4, 7]
                else None,
            )
@@ -39,32 +39,36 @@ _down_encoder_blocks_tensor_names = [
        ],
        transformer_block_tensor_names=[
            stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=f"unet.encoders.{i*3+j+1}.1.groupnorm",
-                conv_in=f"unet.encoders.{i*3+j+1}.1.conv_input",
-                conv_out=f"unet.encoders.{i*3+j+1}.1.conv_output",
+                pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
+                conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
+                conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
                self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
-                    norm=f"unet.encoders.{i*3+j+1}.1.layernorm_1",
-                    fused_qkv_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.in_proj",
-                    output_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.out_proj",
+                    norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
+                    q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
+                    k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_k",
+                    v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_v",
+                    output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_out.0",
                ),
                cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
-                    norm=f"unet.encoders.{i*3+j+1}.1.layernorm_2",
-                    q_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.q_proj",
-                    k_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.k_proj",
-                    v_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.v_proj",
-                    output_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.out_proj",
+                    norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm2",
+                    q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_q",
+                    k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_k",
+                    v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_v",
+                    output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_out.0",
                ),
                feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
-                    norm=f"unet.encoders.{i*3+j+1}.1.layernorm_3",
-                    ge_glu=f"unet.encoders.{i*3+j+1}.1.linear_geglu_1",
-                    w2=f"unet.encoders.{i*3+j+1}.1.linear_geglu_2",
+                    norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm3",
+                    ge_glu=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.0.proj",
+                    w2=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.2",
                ),
            )
            for j in range(2)
        ]
        if i < 3
        else None,
-        downsample_conv=f"unet.encoders.{i*3+3}.0" if i < 3 else None,
+        downsample_conv=f"model.diffusion_model.input_blocks.{i*3+3}.0.op"
+        if i < 3
+        else None,
    )
    for i in range(4)
 ]
@@ -72,35 +76,37 @@ _down_encoder_blocks_tensor_names = [
 _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
    residual_block_tensor_names=[
        stable_diffusion_loader.ResidualBlockTensorNames(
-            norm_1=f"unet.bottleneck.{i}.groupnorm_feature",
-            conv_1=f"unet.bottleneck.{i}.conv_feature",
-            norm_2=f"unet.bottleneck.{i}.groupnorm_merged",
-            conv_2=f"unet.bottleneck.{i}.conv_merged",
-            time_embedding=f"unet.bottleneck.{i}.linear_time",
+            norm_1=f"model.diffusion_model.middle_block.{i}.in_layers.0",
+            conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
+            norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
+            conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
+            time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
        )
        for i in [0, 2]
    ],
    transformer_block_tensor_names=[
        stable_diffusion_loader.TransformerBlockTensorNames(
-            pre_conv_norm=f"unet.bottleneck.{i}.groupnorm",
-            conv_in=f"unet.bottleneck.{i}.conv_input",
-            conv_out=f"unet.bottleneck.{i}.conv_output",
+            pre_conv_norm=f"model.diffusion_model.middle_block.{i}.norm",
+            conv_in=f"model.diffusion_model.middle_block.{i}.proj_in",
+            conv_out=f"model.diffusion_model.middle_block.{i}.proj_out",
            self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
-                norm=f"unet.bottleneck.{i}.layernorm_1",
-                fused_qkv_proj=f"unet.bottleneck.{i}.attention_1.in_proj",
-                output_proj=f"unet.bottleneck.{i}.attention_1.out_proj",
+                norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm1",
+                q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_q",
+                k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_k",
+                v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_v",
+                output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_out.0",
            ),
            cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
-                norm=f"unet.bottleneck.{i}.layernorm_2",
-                q_proj=f"unet.bottleneck.{i}.attention_2.q_proj",
-                k_proj=f"unet.bottleneck.{i}.attention_2.k_proj",
-                v_proj=f"unet.bottleneck.{i}.attention_2.v_proj",
-                output_proj=f"unet.bottleneck.{i}.attention_2.out_proj",
+                norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm2",
+                q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_q",
+                k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_k",
+                v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_v",
+                output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_out.0",
            ),
            feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
-                norm=f"unet.bottleneck.{i}.layernorm_3",
-                ge_glu=f"unet.bottleneck.{i}.linear_geglu_1",
-                w2=f"unet.bottleneck.{i}.linear_geglu_2",
+                norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm3",
+                ge_glu=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.0.proj",
+                w2=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.2",
            ),
        )
        for i in [1]
@@ -111,58 +117,59 @@ _up_decoder_blocks_tensor_names = [
    stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
        residual_block_tensor_names=[
            stable_diffusion_loader.ResidualBlockTensorNames(
-                norm_1=f"unet.decoders.{i*3+j}.0.groupnorm_feature",
-                conv_1=f"unet.decoders.{i*3+j}.0.conv_feature",
-                norm_2=f"unet.decoders.{i*3+j}.0.groupnorm_merged",
-                conv_2=f"unet.decoders.{i*3+j}.0.conv_merged",
-                time_embedding=f"unet.decoders.{i*3+j}.0.linear_time",
-                residual_layer=f"unet.decoders.{i*3+j}.0.residual_layer",
+                norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
+                conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
+                norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
+                conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
+                time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
+                residual_layer=f"model.diffusion_model.output_blocks.{i*3+j}.0.skip_connection",
            )
            for j in range(3)
        ],
        transformer_block_tensor_names=[
            stable_diffusion_loader.TransformerBlockTensorNames(
-                pre_conv_norm=f"unet.decoders.{i*3+j}.1.groupnorm",
-                conv_in=f"unet.decoders.{i*3+j}.1.conv_input",
-                conv_out=f"unet.decoders.{i*3+j}.1.conv_output",
+                pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
+                conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
+                conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
                self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
-                    norm=f"unet.decoders.{i*3+j}.1.layernorm_1",
-                    fused_qkv_proj=f"unet.decoders.{i*3+j}.1.attention_1.in_proj",
-                    output_proj=f"unet.decoders.{i*3+j}.1.attention_1.out_proj",
+                    norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
+                    q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
+                    k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_k",
+                    v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_v",
+                    output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_out.0",
                ),
                cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
-                    norm=f"unet.decoders.{i*3+j}.1.layernorm_2",
-                    q_proj=f"unet.decoders.{i*3+j}.1.attention_2.q_proj",
-                    k_proj=f"unet.decoders.{i*3+j}.1.attention_2.k_proj",
-                    v_proj=f"unet.decoders.{i*3+j}.1.attention_2.v_proj",
-                    output_proj=f"unet.decoders.{i*3+j}.1.attention_2.out_proj",
+                    norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm2",
+                    q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_q",
+                    k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_k",
+                    v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_v",
+                    output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_out.0",
                ),
                feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
-                    norm=f"unet.decoders.{i*3+j}.1.layernorm_3",
-                    ge_glu=f"unet.decoders.{i*3+j}.1.linear_geglu_1",
-                    w2=f"unet.decoders.{i*3+j}.1.linear_geglu_2",
+                    norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm3",
+                    ge_glu=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.0.proj",
+                    w2=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.2",
                ),
            )
            for j in range(3)
        ]
        if i > 0
        else None,
-        upsample_conv=f"unet.decoders.{i*3+2}.2.conv"
+        upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
        if 0 < i < 3
-        else (f"unet.decoders.2.1.conv" if i == 0 else None),
+        else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
    )
    for i in range(4)
 ]
 
-
-TENSORS_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
+TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
    time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames(
-        w1="time_embedding.linear_1",
-        w2="time_embedding.linear_2",
+        w1="model.diffusion_model.time_embed.0",
+        w2="model.diffusion_model.time_embed.2",
    ),
-    conv_in="unet.encoders.0.0",
-    conv_out="final.conv",
-    final_norm="final.groupnorm",
+    conv_in="model.diffusion_model.input_blocks.0.0",
+    conv_out="model.diffusion_model.out.2",
+    final_norm="model.diffusion_model.out.0",
    down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names,
    mid_block_tensor_names=_mid_block_tensor_names,
    up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names,
@@ -249,6 +256,7 @@ class Diffusion(nn.Module):
            qkv_use_bias=False,
            output_proj_use_bias=True,
            enable_kv_cache=False,
+            qkv_fused_interleaved=False,
        )
 
        # Down encoders.
@@ -280,7 +288,7 @@ class Diffusion(nn.Module):
                    stride=2,
                    padding=config.downsample_padding,
                ),
-                transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+                transformer_block_config=unet_cfg.TransformerBlock2DConfig(
                    attention_block_config=unet_cfg.AttentionBlock2DConfig(
                        dim=output_channel,
                        attention_batch_size=config.transformer_batch_size,
@@ -340,7 +348,7 @@ class Diffusion(nn.Module):
            ),
            num_layers=config.mid_block_layers,
            time_embedding_channels=config.time_embedding_blocks_dim,
-            transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+            transformer_block_config=unet_cfg.TransformerBlock2DConfig(
                attention_block_config=unet_cfg.AttentionBlock2DConfig(
                    dim=mid_block_channels,
                    attention_batch_size=config.transformer_batch_size,
@@ -401,7 +409,7 @@ class Diffusion(nn.Module):
                    mode=unet_cfg.SamplingType.NEAREST,
                    scale_factor=2,
                ),
-                transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+                transformer_block_config=unet_cfg.TransformerBlock2DConfig(
                    attention_block_config=unet_cfg.AttentionBlock2DConfig(
                        dim=output_channel,
                        attention_batch_size=config.transformer_batch_size,
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py
@@ -167,7 +167,7 @@ def run_tflite_pipeline(
    if input_image:
        if not hasattr(model, 'encoder'):
            raise AttributeError(
-                'Stable Diffusion must be initilaized with encoder to accept input_image.'
+                'Stable Diffusion must be initialized with encoder to accept input_image.'
            )
        input_image = input_image.resize((width, height))
        input_image_np = np.array(input_image).astype(np.float32)
ai_edge_torch/generative/examples/t5/t5_attention.py
@@ -27,6 +27,8 @@ import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
 from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA
 
+BATCH_SIZE = 1
+
 
 class EncoderDecoderBlock(nn.Module):
 
@@ -44,6 +46,7 @@ class EncoderDecoderBlock(nn.Module):
 
        super().__init__()
        self.atten_func = T5Attention(
+            BATCH_SIZE,
            config.embedding_dim,
            config.attn_config,
            config.pre_attention_norm_config,
@@ -54,6 +57,7 @@ class EncoderDecoderBlock(nn.Module):
        # For a decoder, we add a cross attention.
        if config.is_decoder:
            self.cross_atten_func = T5Attention(
+                BATCH_SIZE,
                config.embedding_dim,
                config.attn_config,
                config.pre_attention_norm_config,
@@ -127,6 +131,7 @@ class T5Attention(CrossAttention):
 
    def __init__(
        self,
+        batch: int,
        dim: int,
        config: cfg.AttentionConfig,
        norm_config: cfg.NormalizationConfig,
@@ -144,7 +149,7 @@ class T5Attention(CrossAttention):
          enable_hlfb (bool): whether hlfb is enabled or not.
          has_relative_attention_bias (bool): whether we compute relative bias.
        """
-        super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
+        super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)
        self.pre_atten_norm = builder.build_norm(dim, norm_config)
 
        self.has_relative_attention_bias = has_relative_attention_bias
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py
@@ -40,7 +40,7 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
                if self.is_zero_tensor_node(source):
                    # Remove the mark_tensor call on the mask input by
                    # replacing the target with an identity function.
-                    node.target = lambda *args, **kwargs: args[0]
+                    node.target = lambda *args, **kwargs: torch.zeros_like(args[0])
 
        exported_program.graph_module.graph.lint()
        exported_program.graph_module.recompile()
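Note: the retargeted node now materializes an explicit zero tensor with torch.zeros_like instead of aliasing its input; the pass has already verified is_zero_tensor_node, so the values are unchanged. A minimal, self-contained sketch of the same node-retargeting mechanism on a toy graph (the traced function and matched op are illustrative, not from this package):

import torch
import torch.fx as fx

def f(x):
    return torch.relu(x)

gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        # Swap the node's callable, mirroring the replacement in the pass above.
        node.target = lambda *args, **kwargs: torch.zeros_like(args[0])
gm.graph.lint()
gm.recompile()
print(gm(torch.ones(2)))  # tensor([0., 0.])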
ai_edge_torch/generative/layers/model_config.py
@@ -68,6 +68,10 @@ class AttentionConfig:
    qkv_transpose_before_split: bool = False
    # Whether to use bias with Query, Key, and Value projection.
    qkv_use_bias: bool = False
+    # Whether the fused q, k, v projection weights interleaves q, k, v heads.
+    # If True, the projection weights are in format [q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]
+    # If False, the projection weights are in format [q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]
+    qkv_fused_interleaved: bool = True
    # Whether to use bias with attention output projection.
    output_proj_use_bias: bool = False
    enable_kv_cache: bool = True
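Note: a small sketch of the two fused-QKV weight layouts the new flag selects between; the toy values, two heads, head_dim=2, and plain multi-head attention are illustrative:

import torch

head_dim = 2
q = torch.tensor([0., 1., 2., 3.])      # rows of q_head_0, q_head_1
k = torch.tensor([10., 11., 12., 13.])  # rows of k_head_0, k_head_1
v = torch.tensor([20., 21., 22., 23.])  # rows of v_head_0, v_head_1

# qkv_fused_interleaved=True: cycle q/k/v head by head
qs, ks, vs = (torch.split(t, head_dim) for t in (q, k, v))
interleaved = torch.cat([t for group in zip(qs, ks, vs) for t in group])
# -> [q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1]

# qkv_fused_interleaved=False: plain q/k/v concatenation
concatenated = torch.cat([q, k, v])
# -> [q_head_0, q_head_1, k_head_0, k_head_1, v_head_0, v_head_1]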
ai_edge_torch/generative/layers/unet/blocks_2d.py
@@ -272,7 +272,7 @@ class TransformerBlock2D(nn.Module):
 
    """
 
-    def __init__(self, config: unet_cfg.TransformerBlock2Dconfig):
+    def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
        """Initialize an instance of the TransformerBlock2D.
 
        Args:
ai_edge_torch/generative/layers/unet/model_config.py
@@ -85,7 +85,7 @@ class FeedForwardBlock2DConfig:
 
 
 @dataclass
-class TransformerBlock2Dconfig:
+class TransformerBlock2DConfig:
    pre_conv_normalization_config: layers_cfg.NormalizationConfig
    attention_block_config: AttentionBlock2DConfig
    cross_attention_block_config: CrossAttentionBlock2DConfig
@@ -108,7 +108,7 @@ class UpDecoderBlock2DConfig:
    # Optional sampling config if add_upsample is True.
    sampling_config: Optional[UpSamplingConfig] = None
    # Optional config of transformer blocks interleaved with residual blocks
-    transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+    transformer_block_config: Optional[TransformerBlock2DConfig] = None
    # Optional dimension of context tensor if context tensor is given as input.
    context_dim: Optional[int] = None
 
@@ -131,7 +131,7 @@ class SkipUpDecoderBlock2DConfig:
    # Optional sampling config if add_upsample is True.
    sampling_config: Optional[UpSamplingConfig] = None
    # Optional config of transformer blocks interleaved with residual blocks
-    transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+    transformer_block_config: Optional[TransformerBlock2DConfig] = None
    # Optional dimension of context tensor if context tensor is given as input.
    context_dim: Optional[int] = None
 
@@ -152,7 +152,7 @@ class DownEncoderBlock2DConfig:
    # Optional sampling config if add_upsample is True.
    sampling_config: Optional[DownSamplingConfig] = None
    # Optional config of transformer blocks interleaved with residual blocks
-    transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+    transformer_block_config: Optional[TransformerBlock2DConfig] = None
    # Optional dimension of context tensor if context tensor is given as input.
    context_dim: Optional[int] = None
 
@@ -168,7 +168,7 @@ class MidBlock2DConfig:
    # Optional config of attention blocks interleaved with residual blocks
    attention_block_config: Optional[AttentionBlock2DConfig] = None
    # Optional config of transformer blocks interleaved with residual blocks
-    transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+    transformer_block_config: Optional[TransformerBlock2DConfig] = None
    # Optional dimension of context tensor if context tensor is given as input.
    context_dim: Optional[int] = None
 
ai_edge_torch/generative/utilities/loader.py
@@ -317,9 +317,12 @@ class ModelLoader:
        k: torch.Tensor,
        v: torch.Tensor,
    ) -> torch.Tensor:
-        q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
-        qs = torch.split(q, config.head_dim * q_per_kv)
-        ks = torch.split(k, config.head_dim)
-        vs = torch.split(v, config.head_dim)
-        cycled = [t for group in zip(qs, ks, vs) for t in group]
-        return torch.cat(cycled)
+        if config.attn_config.qkv_fused_interleaved:
+            q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+            qs = torch.split(q, config.head_dim * q_per_kv)
+            ks = torch.split(k, config.head_dim)
+            vs = torch.split(v, config.head_dim)
+            cycled = [t for group in zip(qs, ks, vs) for t in group]
+            return torch.cat(cycled)
+        else:
+            return torch.cat([q, k, v], dim=0)
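Note: a quick check of the interleaved path above under grouped-query attention; num_heads=4, num_query_groups=2, head_dim=1 and the toy row values are illustrative:

import torch

num_heads, num_query_groups, head_dim = 4, 2, 1
q_per_kv = num_heads // num_query_groups      # 2 query heads per KV group
q = torch.tensor([0., 1., 2., 3.])            # one row per query head
k = torch.tensor([10., 11.])                  # one row per KV group
v = torch.tensor([20., 21.])

qs = torch.split(q, head_dim * q_per_kv)      # ([0., 1.], [2., 3.])
ks = torch.split(k, head_dim)                 # ([10.], [11.])
vs = torch.split(v, head_dim)                 # ([20.], [21.])
fused = torch.cat([t for group in zip(qs, ks, vs) for t in group])
print(fused)  # tensor([ 0.,  1., 10., 20.,  2.,  3., 11., 21.])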