ai-edge-torch-nightly 0.2.0.dev20240710__py3-none-any.whl → 0.2.0.dev20240711__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of ai-edge-torch-nightly has been flagged as a potentially problematic release.

@@ -23,16 +23,17 @@ import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
- ff_up_proj="layers.{}.linear_1",
- ff_down_proj="layers.{}.linear_2",
- ff_gate_proj="layers.{}.linear_1",
- attn_fused_qkv_proj="layers.{}.attention.in_proj",
- attn_output_proj="layers.{}.attention.out_proj",
- pre_attn_norm="layers.{}.layernorm_1",
- pre_ff_norm="layers.{}.layernorm_2",
- embedding="embedding.token_embedding",
- embedding_position="embedding.position_value",
- final_norm="layernorm",
+ ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
+ ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+ attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
+ attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
+ attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
+ attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
+ pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
+ pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
+ embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+ embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
+ final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
  lm_head=None,
  )

@@ -84,6 +85,7 @@ def get_model_config() -> cfg.ModelConfig:
  rotary_percentage=0.0,
  qkv_use_bias=True,
  qkv_transpose_before_split=True,
+ qkv_fused_interleaved=False,
  output_proj_use_bias=True,
  enable_kv_cache=False,
  )
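Each templated entry in the TENSOR_NAMES hunk above carries a {} placeholder that the loader fills with the transformer layer index, so the new names line up with the CLIP text-encoder keys of a standard Stable Diffusion checkpoint. A minimal sketch of how layer 0 would resolve, using plain string formatting (the loader's own formatting call is not shown in this hunk):

ff_up_proj = "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1"
attn_query_proj = "cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj"

layer = 0
print(ff_up_proj.format(layer))       # cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1
print(attn_query_proj.format(layer))  # cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj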
@@ -13,8 +13,10 @@
  # limitations under the License.
  # ==============================================================================

+ import argparse
  import os
  from pathlib import Path
+ from typing import Optional

  import torch

@@ -24,14 +26,36 @@ import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
  import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
  from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
  import ai_edge_torch.generative.examples.stable_diffusion.util as util
- import ai_edge_torch.generative.utilities.loader as loading_utils
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument(
+ '--clip_ckpt', type=str, help='Path to source CLIP model checkpoint', required=True
+ )
+ arg_parser.add_argument(
+ '--diffusion_ckpt',
+ type=str,
+ help='Path to source diffusion model checkpoint',
+ required=True,
+ )
+ arg_parser.add_argument(
+ '--decoder_ckpt',
+ type=str,
+ help='Path to source image decoder model checkpoint',
+ required=True,
+ )
+ arg_parser.add_argument(
+ '--output_dir',
+ type=str,
+ help='Path to the converted TF Lite directory.',
+ required=True,
+ )
+

  @torch.inference_mode
  def convert_stable_diffusion_to_tflite(
+ output_dir: str,
  clip_ckpt_path: str,
- encoder_ckpt_path: str,
  diffusion_ckpt_path: str,
  decoder_ckpt_path: str,
  image_height: int = 512,
@@ -39,23 +63,28 @@ def convert_stable_diffusion_to_tflite(
  ):

  clip_model = clip.CLIP(clip.get_model_config())
- loader = loading_utils.ModelLoader(clip_ckpt_path, clip.TENSOR_NAMES)
+ loader = stable_diffusion_loader.ClipModelLoader(
+ clip_ckpt_path,
+ clip.TENSOR_NAMES,
+ )
  loader.load(clip_model, strict=False)

- encoder = Encoder()
- encoder.load_state_dict(torch.load(encoder_ckpt_path))
-
  diffusion_model = diffusion.Diffusion(diffusion.get_model_config(2))
  diffusion_loader = stable_diffusion_loader.DiffusionModelLoader(
- diffusion_ckpt_path, diffusion.TENSORS_NAMES
+ diffusion_ckpt_path, diffusion.TENSOR_NAMES
  )
- diffusion_loader.load(diffusion_model)
+ diffusion_loader.load(diffusion_model, strict=False)

  decoder_model = decoder.Decoder(decoder.get_model_config())
  decoder_loader = stable_diffusion_loader.AutoEncoderModelLoader(
- decoder_ckpt_path, decoder.TENSORS_NAMES
+ decoder_ckpt_path, decoder.TENSOR_NAMES
  )
- decoder_loader.load(decoder_model)
+ decoder_loader.load(decoder_model, strict=False)
+
+ # TODO(yichunk): enable image encoder conversion
+ # if encoder_ckpt_path is not None:
+ # encoder = Encoder()
+ # encoder.load_state_dict(torch.load(encoder_ckpt_path))

  # Tensors used to trace the model graph during conversion.
  n_tokens = 77
@@ -67,50 +96,47 @@ def convert_stable_diffusion_to_tflite(
  (len_prompt, 4, image_height // 8, image_width // 8), 0, dtype=torch.float32
  )

- input_latents = encoder(input_image, noise)
+ input_latents = torch.zeros_like(noise)
  context_cond = clip_model(prompt_tokens)
  context_uncond = torch.zeros_like(context_cond)
  context = torch.cat([context_cond, context_uncond], axis=0)
  time_embedding = util.get_time_embedding(timestamp)

+ if not os.path.exists(output_dir):
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+ # TODO(yichunk): convert to multi signature tflite model.
  # CLIP text encoder
  ai_edge_torch.signature('encode', clip_model, (prompt_tokens,)).convert().export(
- '/tmp/stable_diffusion/clip.tflite'
+ f'{output_dir}/clip.tflite'
  )

- # TODO(yichunk): convert to multi signature tflite model.
+ # TODO(yichunk): enable image encoder conversion
  # Image encoder
- ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
- '/tmp/stable_diffusion/encoder.tflite'
- )
+ # ai_edge_torch.signature('encode', encoder, (input_image, noise)).convert().export(
+ # f'{output_dir}/encoder.tflite'
+ # )

  # Diffusion
  ai_edge_torch.signature(
  'diffusion',
  diffusion_model,
  (torch.repeat_interleave(input_latents, 2, 0), context, time_embedding),
- ).convert().export('/tmp/stable_diffusion/diffusion.tflite')
+ ).convert().export(f'{output_dir}/diffusion.tflite')

  # Image decoder
  ai_edge_torch.signature('decode', decoder_model, (input_latents,)).convert().export(
- '/tmp/stable_diffusion/decoder.tflite'
+ f'{output_dir}/decoder.tflite'
  )


  if __name__ == '__main__':
+ args = arg_parser.parse_args()
  convert_stable_diffusion_to_tflite(
- clip_ckpt_path=os.path.join(
- Path.home(), 'Downloads/stable_diffusion_data/ckpt/clip.pt'
- ),
- encoder_ckpt_path=os.path.join(
- Path.home(), 'Downloads/stable_diffusion_data/ckpt/encoder.pt'
- ),
- diffusion_ckpt_path=os.path.join(
- Path.home(), 'Downloads/stable_diffusion_data/ckpt/diffusion.pt'
- ),
- decoder_ckpt_path=os.path.join(
- Path.home(), 'Downloads/stable_diffusion_data/ckpt/decoder.pt'
- ),
+ output_dir=args.output_dir,
+ clip_ckpt_path=args.clip_ckpt,
+ diffusion_ckpt_path=args.diffusion_ckpt,
+ decoder_ckpt_path=args.decoder_ckpt,
  image_height=512,
  image_width=512,
  )
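With the hard-coded ~/Downloads checkpoint paths gone, the converter is now driven entirely by the required flags (--clip_ckpt, --diffusion_ckpt, --decoder_ckpt, --output_dir). A minimal sketch of driving the same function directly from Python, using placeholder paths that are illustrative only and assuming the module is imported by its package path:

from ai_edge_torch.generative.examples.stable_diffusion import convert_to_tflite

convert_to_tflite.convert_stable_diffusion_to_tflite(
    output_dir='/tmp/stable_diffusion_tflite',  # placeholder output directory
    clip_ckpt_path='/path/to/clip.pt',          # placeholder checkpoint paths
    diffusion_ckpt_path='/path/to/diffusion.pt',
    decoder_ckpt_path='/path/to/decoder.pt',
    image_height=512,
    image_width=512,
)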
@@ -22,29 +22,31 @@ import ai_edge_torch.generative.layers.unet.blocks_2d as blocks_2d
22
22
  import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
23
23
  import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
24
24
 
25
- TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
26
- post_quant_conv="0",
27
- conv_in="1",
25
+ TENSOR_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
26
+ post_quant_conv="first_stage_model.post_quant_conv",
27
+ conv_in="first_stage_model.decoder.conv_in",
28
28
  mid_block_tensor_names=stable_diffusion_loader.MidBlockTensorNames(
29
29
  residual_block_tensor_names=[
30
30
  stable_diffusion_loader.ResidualBlockTensorNames(
31
- norm_1="2.groupnorm_1",
32
- norm_2="2.groupnorm_2",
33
- conv_1="2.conv_1",
34
- conv_2="2.conv_2",
31
+ norm_1="first_stage_model.decoder.mid.block_1.norm1",
32
+ norm_2="first_stage_model.decoder.mid.block_1.norm2",
33
+ conv_1="first_stage_model.decoder.mid.block_1.conv1",
34
+ conv_2="first_stage_model.decoder.mid.block_1.conv2",
35
35
  ),
36
36
  stable_diffusion_loader.ResidualBlockTensorNames(
37
- norm_1="4.groupnorm_1",
38
- norm_2="4.groupnorm_2",
39
- conv_1="4.conv_1",
40
- conv_2="4.conv_2",
37
+ norm_1="first_stage_model.decoder.mid.block_2.norm1",
38
+ norm_2="first_stage_model.decoder.mid.block_2.norm2",
39
+ conv_1="first_stage_model.decoder.mid.block_2.conv1",
40
+ conv_2="first_stage_model.decoder.mid.block_2.conv2",
41
41
  ),
42
42
  ],
43
43
  attention_block_tensor_names=[
44
44
  stable_diffusion_loader.AttentionBlockTensorNames(
45
- norm="3.groupnorm",
46
- fused_qkv_proj="3.attention.in_proj",
47
- output_proj="3.attention.out_proj",
45
+ norm="first_stage_model.decoder.mid.attn_1.norm",
46
+ q_proj="first_stage_model.decoder.mid.attn_1.q",
47
+ k_proj="first_stage_model.decoder.mid.attn_1.k",
48
+ v_proj="first_stage_model.decoder.mid.attn_1.v",
49
+ output_proj="first_stage_model.decoder.mid.attn_1.proj_out",
48
50
  )
49
51
  ],
50
52
  ),
@@ -52,99 +54,99 @@ TENSORS_NAMES = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
52
54
  stable_diffusion_loader.UpDecoderBlockTensorNames(
53
55
  residual_block_tensor_names=[
54
56
  stable_diffusion_loader.ResidualBlockTensorNames(
55
- norm_1="5.groupnorm_1",
56
- norm_2="5.groupnorm_2",
57
- conv_1="5.conv_1",
58
- conv_2="5.conv_2",
57
+ norm_1="first_stage_model.decoder.up.3.block.0.norm1",
58
+ norm_2="first_stage_model.decoder.up.3.block.0.norm2",
59
+ conv_1="first_stage_model.decoder.up.3.block.0.conv1",
60
+ conv_2="first_stage_model.decoder.up.3.block.0.conv2",
59
61
  ),
60
62
  stable_diffusion_loader.ResidualBlockTensorNames(
61
- norm_1="6.groupnorm_1",
62
- norm_2="6.groupnorm_2",
63
- conv_1="6.conv_1",
64
- conv_2="6.conv_2",
63
+ norm_1="first_stage_model.decoder.up.3.block.1.norm1",
64
+ norm_2="first_stage_model.decoder.up.3.block.1.norm2",
65
+ conv_1="first_stage_model.decoder.up.3.block.1.conv1",
66
+ conv_2="first_stage_model.decoder.up.3.block.1.conv2",
65
67
  ),
66
68
  stable_diffusion_loader.ResidualBlockTensorNames(
67
- norm_1="7.groupnorm_1",
68
- norm_2="7.groupnorm_2",
69
- conv_1="7.conv_1",
70
- conv_2="7.conv_2",
69
+ norm_1="first_stage_model.decoder.up.3.block.2.norm1",
70
+ norm_2="first_stage_model.decoder.up.3.block.2.norm2",
71
+ conv_1="first_stage_model.decoder.up.3.block.2.conv1",
72
+ conv_2="first_stage_model.decoder.up.3.block.2.conv2",
71
73
  ),
72
74
  ],
73
- upsample_conv="9",
75
+ upsample_conv="first_stage_model.decoder.up.3.upsample.conv",
74
76
  ),
75
77
  stable_diffusion_loader.UpDecoderBlockTensorNames(
76
78
  residual_block_tensor_names=[
77
79
  stable_diffusion_loader.ResidualBlockTensorNames(
78
- norm_1="10.groupnorm_1",
79
- norm_2="10.groupnorm_2",
80
- conv_1="10.conv_1",
81
- conv_2="10.conv_2",
80
+ norm_1="first_stage_model.decoder.up.2.block.0.norm1",
81
+ norm_2="first_stage_model.decoder.up.2.block.0.norm2",
82
+ conv_1="first_stage_model.decoder.up.2.block.0.conv1",
83
+ conv_2="first_stage_model.decoder.up.2.block.0.conv2",
82
84
  ),
83
85
  stable_diffusion_loader.ResidualBlockTensorNames(
84
- norm_1="11.groupnorm_1",
85
- norm_2="11.groupnorm_2",
86
- conv_1="11.conv_1",
87
- conv_2="11.conv_2",
86
+ norm_1="first_stage_model.decoder.up.2.block.1.norm1",
87
+ norm_2="first_stage_model.decoder.up.2.block.1.norm2",
88
+ conv_1="first_stage_model.decoder.up.2.block.1.conv1",
89
+ conv_2="first_stage_model.decoder.up.2.block.1.conv2",
88
90
  ),
89
91
  stable_diffusion_loader.ResidualBlockTensorNames(
90
- norm_1="12.groupnorm_1",
91
- norm_2="12.groupnorm_2",
92
- conv_1="12.conv_1",
93
- conv_2="12.conv_2",
92
+ norm_1="first_stage_model.decoder.up.2.block.2.norm1",
93
+ norm_2="first_stage_model.decoder.up.2.block.2.norm2",
94
+ conv_1="first_stage_model.decoder.up.2.block.2.conv1",
95
+ conv_2="first_stage_model.decoder.up.2.block.2.conv2",
94
96
  ),
95
97
  ],
96
- upsample_conv="14",
98
+ upsample_conv="first_stage_model.decoder.up.2.upsample.conv",
97
99
  ),
98
100
  stable_diffusion_loader.UpDecoderBlockTensorNames(
99
101
  residual_block_tensor_names=[
100
102
  stable_diffusion_loader.ResidualBlockTensorNames(
101
- norm_1="15.groupnorm_1",
102
- norm_2="15.groupnorm_2",
103
- conv_1="15.conv_1",
104
- conv_2="15.conv_2",
105
- residual_layer="15.residual_layer",
103
+ norm_1="first_stage_model.decoder.up.1.block.0.norm1",
104
+ norm_2="first_stage_model.decoder.up.1.block.0.norm2",
105
+ conv_1="first_stage_model.decoder.up.1.block.0.conv1",
106
+ conv_2="first_stage_model.decoder.up.1.block.0.conv2",
107
+ residual_layer="first_stage_model.decoder.up.1.block.0.nin_shortcut",
106
108
  ),
107
109
  stable_diffusion_loader.ResidualBlockTensorNames(
108
- norm_1="16.groupnorm_1",
109
- norm_2="16.groupnorm_2",
110
- conv_1="16.conv_1",
111
- conv_2="16.conv_2",
110
+ norm_1="first_stage_model.decoder.up.1.block.1.norm1",
111
+ norm_2="first_stage_model.decoder.up.1.block.1.norm2",
112
+ conv_1="first_stage_model.decoder.up.1.block.1.conv1",
113
+ conv_2="first_stage_model.decoder.up.1.block.1.conv2",
112
114
  ),
113
115
  stable_diffusion_loader.ResidualBlockTensorNames(
114
- norm_1="17.groupnorm_1",
115
- norm_2="17.groupnorm_2",
116
- conv_1="17.conv_1",
117
- conv_2="17.conv_2",
116
+ norm_1="first_stage_model.decoder.up.1.block.2.norm1",
117
+ norm_2="first_stage_model.decoder.up.1.block.2.norm2",
118
+ conv_1="first_stage_model.decoder.up.1.block.2.conv1",
119
+ conv_2="first_stage_model.decoder.up.1.block.2.conv2",
118
120
  ),
119
121
  ],
120
- upsample_conv="19",
122
+ upsample_conv="first_stage_model.decoder.up.1.upsample.conv",
121
123
  ),
122
124
  stable_diffusion_loader.UpDecoderBlockTensorNames(
123
125
  residual_block_tensor_names=[
124
126
  stable_diffusion_loader.ResidualBlockTensorNames(
125
- norm_1="20.groupnorm_1",
126
- norm_2="20.groupnorm_2",
127
- conv_1="20.conv_1",
128
- conv_2="20.conv_2",
129
- residual_layer="20.residual_layer",
127
+ norm_1="first_stage_model.decoder.up.0.block.0.norm1",
128
+ norm_2="first_stage_model.decoder.up.0.block.0.norm2",
129
+ conv_1="first_stage_model.decoder.up.0.block.0.conv1",
130
+ conv_2="first_stage_model.decoder.up.0.block.0.conv2",
131
+ residual_layer="first_stage_model.decoder.up.0.block.0.nin_shortcut",
130
132
  ),
131
133
  stable_diffusion_loader.ResidualBlockTensorNames(
132
- norm_1="21.groupnorm_1",
133
- norm_2="21.groupnorm_2",
134
- conv_1="21.conv_1",
135
- conv_2="21.conv_2",
134
+ norm_1="first_stage_model.decoder.up.0.block.1.norm1",
135
+ norm_2="first_stage_model.decoder.up.0.block.1.norm2",
136
+ conv_1="first_stage_model.decoder.up.0.block.1.conv1",
137
+ conv_2="first_stage_model.decoder.up.0.block.1.conv2",
136
138
  ),
137
139
  stable_diffusion_loader.ResidualBlockTensorNames(
138
- norm_1="22.groupnorm_1",
139
- norm_2="22.groupnorm_2",
140
- conv_1="22.conv_1",
141
- conv_2="22.conv_2",
140
+ norm_1="first_stage_model.decoder.up.0.block.2.norm1",
141
+ norm_2="first_stage_model.decoder.up.0.block.2.norm2",
142
+ conv_1="first_stage_model.decoder.up.0.block.2.conv1",
143
+ conv_2="first_stage_model.decoder.up.0.block.2.conv2",
142
144
  ),
143
145
  ],
144
146
  ),
145
147
  ],
146
- final_norm="23",
147
- conv_out="25",
148
+ final_norm="first_stage_model.decoder.norm_out",
149
+ conv_out="first_stage_model.decoder.conv_out",
148
150
  )
149
151
 
150
152
 
@@ -288,6 +290,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
  output_proj_use_bias=True,
  enable_kv_cache=False,
  qkv_transpose_before_split=True,
+ qkv_fused_interleaved=False,
  rotary_percentage=0.0,
  ),
  )
@@ -26,12 +26,12 @@ _down_encoder_blocks_tensor_names = [
26
26
  stable_diffusion_loader.DownEncoderBlockTensorNames(
27
27
  residual_block_tensor_names=[
28
28
  stable_diffusion_loader.ResidualBlockTensorNames(
29
- norm_1=f"unet.encoders.{i*3+j+1}.0.groupnorm_feature",
30
- conv_1=f"unet.encoders.{i*3+j+1}.0.conv_feature",
31
- norm_2=f"unet.encoders.{i*3+j+1}.0.groupnorm_merged",
32
- conv_2=f"unet.encoders.{i*3+j+1}.0.conv_merged",
33
- time_embedding=f"unet.encoders.{i*3+j+1}.0.linear_time",
34
- residual_layer=f"unet.encoders.{i*3+j+1}.0.residual_layer"
29
+ norm_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.0",
30
+ conv_1=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.in_layers.2",
31
+ norm_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.0",
32
+ conv_2=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.out_layers.3",
33
+ time_embedding=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.emb_layers.1",
34
+ residual_layer=f"model.diffusion_model.input_blocks.{i*3+j+1}.0.skip_connection"
35
35
  if (i * 3 + j + 1) in [4, 7]
36
36
  else None,
37
37
  )
@@ -39,32 +39,36 @@ _down_encoder_blocks_tensor_names = [
39
39
  ],
40
40
  transformer_block_tensor_names=[
41
41
  stable_diffusion_loader.TransformerBlockTensorNames(
42
- pre_conv_norm=f"unet.encoders.{i*3+j+1}.1.groupnorm",
43
- conv_in=f"unet.encoders.{i*3+j+1}.1.conv_input",
44
- conv_out=f"unet.encoders.{i*3+j+1}.1.conv_output",
42
+ pre_conv_norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.norm",
43
+ conv_in=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_in",
44
+ conv_out=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.proj_out",
45
45
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
46
- norm=f"unet.encoders.{i*3+j+1}.1.layernorm_1",
47
- fused_qkv_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.in_proj",
48
- output_proj=f"unet.encoders.{i*3+j+1}.1.attention_1.out_proj",
46
+ norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm1",
47
+ q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_q",
48
+ k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_k",
49
+ v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_v",
50
+ output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn1.to_out.0",
49
51
  ),
50
52
  cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
51
- norm=f"unet.encoders.{i*3+j+1}.1.layernorm_2",
52
- q_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.q_proj",
53
- k_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.k_proj",
54
- v_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.v_proj",
55
- output_proj=f"unet.encoders.{i*3+j+1}.1.attention_2.out_proj",
53
+ norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm2",
54
+ q_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_q",
55
+ k_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_k",
56
+ v_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_v",
57
+ output_proj=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.attn2.to_out.0",
56
58
  ),
57
59
  feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
58
- norm=f"unet.encoders.{i*3+j+1}.1.layernorm_3",
59
- ge_glu=f"unet.encoders.{i*3+j+1}.1.linear_geglu_1",
60
- w2=f"unet.encoders.{i*3+j+1}.1.linear_geglu_2",
60
+ norm=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.norm3",
61
+ ge_glu=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.0.proj",
62
+ w2=f"model.diffusion_model.input_blocks.{i*3+j+1}.1.transformer_blocks.0.ff.net.2",
61
63
  ),
62
64
  )
63
65
  for j in range(2)
64
66
  ]
65
67
  if i < 3
66
68
  else None,
67
- downsample_conv=f"unet.encoders.{i*3+3}.0" if i < 3 else None,
69
+ downsample_conv=f"model.diffusion_model.input_blocks.{i*3+3}.0.op"
70
+ if i < 3
71
+ else None,
68
72
  )
69
73
  for i in range(4)
70
74
  ]
@@ -72,35 +76,37 @@ _down_encoder_blocks_tensor_names = [
72
76
  _mid_block_tensor_names = stable_diffusion_loader.MidBlockTensorNames(
73
77
  residual_block_tensor_names=[
74
78
  stable_diffusion_loader.ResidualBlockTensorNames(
75
- norm_1=f"unet.bottleneck.{i}.groupnorm_feature",
76
- conv_1=f"unet.bottleneck.{i}.conv_feature",
77
- norm_2=f"unet.bottleneck.{i}.groupnorm_merged",
78
- conv_2=f"unet.bottleneck.{i}.conv_merged",
79
- time_embedding=f"unet.bottleneck.{i}.linear_time",
79
+ norm_1=f"model.diffusion_model.middle_block.{i}.in_layers.0",
80
+ conv_1=f"model.diffusion_model.middle_block.{i}.in_layers.2",
81
+ norm_2=f"model.diffusion_model.middle_block.{i}.out_layers.0",
82
+ conv_2=f"model.diffusion_model.middle_block.{i}.out_layers.3",
83
+ time_embedding=f"model.diffusion_model.middle_block.{i}.emb_layers.1",
80
84
  )
81
85
  for i in [0, 2]
82
86
  ],
83
87
  transformer_block_tensor_names=[
84
88
  stable_diffusion_loader.TransformerBlockTensorNames(
85
- pre_conv_norm=f"unet.bottleneck.{i}.groupnorm",
86
- conv_in=f"unet.bottleneck.{i}.conv_input",
87
- conv_out=f"unet.bottleneck.{i}.conv_output",
89
+ pre_conv_norm=f"model.diffusion_model.middle_block.{i}.norm",
90
+ conv_in=f"model.diffusion_model.middle_block.{i}.proj_in",
91
+ conv_out=f"model.diffusion_model.middle_block.{i}.proj_out",
88
92
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
89
- norm=f"unet.bottleneck.{i}.layernorm_1",
90
- fused_qkv_proj=f"unet.bottleneck.{i}.attention_1.in_proj",
91
- output_proj=f"unet.bottleneck.{i}.attention_1.out_proj",
93
+ norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm1",
94
+ q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_q",
95
+ k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_k",
96
+ v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_v",
97
+ output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn1.to_out.0",
92
98
  ),
93
99
  cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
94
- norm=f"unet.bottleneck.{i}.layernorm_2",
95
- q_proj=f"unet.bottleneck.{i}.attention_2.q_proj",
96
- k_proj=f"unet.bottleneck.{i}.attention_2.k_proj",
97
- v_proj=f"unet.bottleneck.{i}.attention_2.v_proj",
98
- output_proj=f"unet.bottleneck.{i}.attention_2.out_proj",
100
+ norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm2",
101
+ q_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_q",
102
+ k_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_k",
103
+ v_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_v",
104
+ output_proj=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.attn2.to_out.0",
99
105
  ),
100
106
  feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
101
- norm=f"unet.bottleneck.{i}.layernorm_3",
102
- ge_glu=f"unet.bottleneck.{i}.linear_geglu_1",
103
- w2=f"unet.bottleneck.{i}.linear_geglu_2",
107
+ norm=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.norm3",
108
+ ge_glu=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.0.proj",
109
+ w2=f"model.diffusion_model.middle_block.{i}.transformer_blocks.0.ff.net.2",
104
110
  ),
105
111
  )
106
112
  for i in [1]
@@ -111,58 +117,59 @@ _up_decoder_blocks_tensor_names = [
111
117
  stable_diffusion_loader.SkipUpDecoderBlockTensorNames(
112
118
  residual_block_tensor_names=[
113
119
  stable_diffusion_loader.ResidualBlockTensorNames(
114
- norm_1=f"unet.decoders.{i*3+j}.0.groupnorm_feature",
115
- conv_1=f"unet.decoders.{i*3+j}.0.conv_feature",
116
- norm_2=f"unet.decoders.{i*3+j}.0.groupnorm_merged",
117
- conv_2=f"unet.decoders.{i*3+j}.0.conv_merged",
118
- time_embedding=f"unet.decoders.{i*3+j}.0.linear_time",
119
- residual_layer=f"unet.decoders.{i*3+j}.0.residual_layer",
120
+ norm_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.0",
121
+ conv_1=f"model.diffusion_model.output_blocks.{i*3+j}.0.in_layers.2",
122
+ norm_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.0",
123
+ conv_2=f"model.diffusion_model.output_blocks.{i*3+j}.0.out_layers.3",
124
+ time_embedding=f"model.diffusion_model.output_blocks.{i*3+j}.0.emb_layers.1",
125
+ residual_layer=f"model.diffusion_model.output_blocks.{i*3+j}.0.skip_connection",
120
126
  )
121
127
  for j in range(3)
122
128
  ],
123
129
  transformer_block_tensor_names=[
124
130
  stable_diffusion_loader.TransformerBlockTensorNames(
125
- pre_conv_norm=f"unet.decoders.{i*3+j}.1.groupnorm",
126
- conv_in=f"unet.decoders.{i*3+j}.1.conv_input",
127
- conv_out=f"unet.decoders.{i*3+j}.1.conv_output",
131
+ pre_conv_norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.norm",
132
+ conv_in=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_in",
133
+ conv_out=f"model.diffusion_model.output_blocks.{i*3+j}.1.proj_out",
128
134
  self_attention=stable_diffusion_loader.AttentionBlockTensorNames(
129
- norm=f"unet.decoders.{i*3+j}.1.layernorm_1",
130
- fused_qkv_proj=f"unet.decoders.{i*3+j}.1.attention_1.in_proj",
131
- output_proj=f"unet.decoders.{i*3+j}.1.attention_1.out_proj",
135
+ norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm1",
136
+ q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_q",
137
+ k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_k",
138
+ v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_v",
139
+ output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn1.to_out.0",
132
140
  ),
133
141
  cross_attention=stable_diffusion_loader.CrossAttentionBlockTensorNames(
134
- norm=f"unet.decoders.{i*3+j}.1.layernorm_2",
135
- q_proj=f"unet.decoders.{i*3+j}.1.attention_2.q_proj",
136
- k_proj=f"unet.decoders.{i*3+j}.1.attention_2.k_proj",
137
- v_proj=f"unet.decoders.{i*3+j}.1.attention_2.v_proj",
138
- output_proj=f"unet.decoders.{i*3+j}.1.attention_2.out_proj",
142
+ norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm2",
143
+ q_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_q",
144
+ k_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_k",
145
+ v_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_v",
146
+ output_proj=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.attn2.to_out.0",
139
147
  ),
140
148
  feed_forward=stable_diffusion_loader.FeedForwardBlockTensorNames(
141
- norm=f"unet.decoders.{i*3+j}.1.layernorm_3",
142
- ge_glu=f"unet.decoders.{i*3+j}.1.linear_geglu_1",
143
- w2=f"unet.decoders.{i*3+j}.1.linear_geglu_2",
149
+ norm=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.norm3",
150
+ ge_glu=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.0.proj",
151
+ w2=f"model.diffusion_model.output_blocks.{i*3+j}.1.transformer_blocks.0.ff.net.2",
144
152
  ),
145
153
  )
146
154
  for j in range(3)
147
155
  ]
148
156
  if i > 0
149
157
  else None,
150
- upsample_conv=f"unet.decoders.{i*3+2}.2.conv"
158
+ upsample_conv=f"model.diffusion_model.output_blocks.{i*3+2}.2.conv"
151
159
  if 0 < i < 3
152
- else (f"unet.decoders.2.1.conv" if i == 0 else None),
160
+ else (f"model.diffusion_model.output_blocks.2.1.conv" if i == 0 else None),
153
161
  )
154
162
  for i in range(4)
155
163
  ]
156
164
 
157
-
158
- TENSORS_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
165
+ TENSOR_NAMES = stable_diffusion_loader.DiffusionModelLoader.TensorNames(
159
166
  time_embedding=stable_diffusion_loader.TimeEmbeddingTensorNames(
160
- w1="time_embedding.linear_1",
161
- w2="time_embedding.linear_2",
167
+ w1="model.diffusion_model.time_embed.0",
168
+ w2="model.diffusion_model.time_embed.2",
162
169
  ),
163
- conv_in="unet.encoders.0.0",
164
- conv_out="final.conv",
165
- final_norm="final.groupnorm",
170
+ conv_in="model.diffusion_model.input_blocks.0.0",
171
+ conv_out="model.diffusion_model.out.2",
172
+ final_norm="model.diffusion_model.out.0",
166
173
  down_encoder_blocks_tensor_names=_down_encoder_blocks_tensor_names,
167
174
  mid_block_tensor_names=_mid_block_tensor_names,
168
175
  up_decoder_blocks_tensor_names=_up_decoder_blocks_tensor_names,
@@ -249,6 +256,7 @@ class Diffusion(nn.Module):
  qkv_use_bias=False,
  output_proj_use_bias=True,
  enable_kv_cache=False,
+ qkv_fused_interleaved=False,
  )

  # Down encoders.
@@ -280,7 +288,7 @@ class Diffusion(nn.Module):
  stride=2,
  padding=config.downsample_padding,
  ),
- transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+ transformer_block_config=unet_cfg.TransformerBlock2DConfig(
  attention_block_config=unet_cfg.AttentionBlock2DConfig(
  dim=output_channel,
  attention_batch_size=config.transformer_batch_size,
@@ -340,7 +348,7 @@ class Diffusion(nn.Module):
  ),
  num_layers=config.mid_block_layers,
  time_embedding_channels=config.time_embedding_blocks_dim,
- transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+ transformer_block_config=unet_cfg.TransformerBlock2DConfig(
  attention_block_config=unet_cfg.AttentionBlock2DConfig(
  dim=mid_block_channels,
  attention_batch_size=config.transformer_batch_size,
@@ -401,7 +409,7 @@ class Diffusion(nn.Module):
  mode=unet_cfg.SamplingType.NEAREST,
  scale_factor=2,
  ),
- transformer_block_config=unet_cfg.TransformerBlock2Dconfig(
+ transformer_block_config=unet_cfg.TransformerBlock2DConfig(
  attention_block_config=unet_cfg.AttentionBlock2DConfig(
  dim=output_channel,
  attention_batch_size=config.transformer_batch_size,
@@ -167,7 +167,7 @@ def run_tflite_pipeline(
  if input_image:
  if not hasattr(model, 'encoder'):
  raise AttributeError(
- 'Stable Diffusion must be initilaized with encoder to accept input_image.'
+ 'Stable Diffusion must be initialized with encoder to accept input_image.'
  )
  input_image = input_image.resize((width, height))
  input_image_np = np.array(input_image).astype(np.float32)
@@ -27,6 +27,8 @@ import ai_edge_torch.generative.layers.model_config as cfg
  from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention # NOQA
  from ai_edge_torch.generative.layers.scaled_dot_product_attention import scaled_dot_product_attention_with_hlfb # NOQA

+ BATCH_SIZE = 1
+

  class EncoderDecoderBlock(nn.Module):

@@ -44,6 +46,7 @@ class EncoderDecoderBlock(nn.Module):

  super().__init__()
  self.atten_func = T5Attention(
+ BATCH_SIZE,
  config.embedding_dim,
  config.attn_config,
  config.pre_attention_norm_config,
@@ -54,6 +57,7 @@ class EncoderDecoderBlock(nn.Module):
  # For a decoder, we add a cross attention.
  if config.is_decoder:
  self.cross_atten_func = T5Attention(
+ BATCH_SIZE,
  config.embedding_dim,
  config.attn_config,
  config.pre_attention_norm_config,
@@ -127,6 +131,7 @@ class T5Attention(CrossAttention):

  def __init__(
  self,
+ batch: int,
  dim: int,
  config: cfg.AttentionConfig,
  norm_config: cfg.NormalizationConfig,
@@ -144,7 +149,7 @@ class T5Attention(CrossAttention):
  enable_hlfb (bool): whether hlfb is enabled or not.
  has_relative_attention_bias (bool): whether we compute relative bias.
  """
- super().__init__(dim, dim, config, kv_cache_max, enable_hlfb)
+ super().__init__(batch, dim, dim, config, kv_cache_max, enable_hlfb)
  self.pre_atten_norm = builder.build_norm(dim, norm_config)

  self.has_relative_attention_bias = has_relative_attention_bias
@@ -68,6 +68,10 @@ class AttentionConfig:
  qkv_transpose_before_split: bool = False
  # Whether to use bias with Query, Key, and Value projection.
  qkv_use_bias: bool = False
+ # Whether the fused q, k, v projection weights interleaves q, k, v heads.
+ # If True, the projection weights are in format [q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]
+ # If False, the projection weights are in format [q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]
+ qkv_fused_interleaved: bool = True
  # Whether to use bias with attention output projection.
  output_proj_use_bias: bool = False
  enable_kv_cache: bool = True
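The new flag only matters for models whose query, key, and value projections are loaded as one fused matrix. A small sketch of the two layouts for a toy case with 2 heads of head_dim 4 (illustrative shapes only, mirroring the split/zip logic in the loader hunk further below):

import torch

head_dim, num_heads = 4, 2
q = torch.randn(num_heads * head_dim, 8)  # per-projection weights, [out_features, in_features]
k = torch.randn(num_heads * head_dim, 8)
v = torch.randn(num_heads * head_dim, 8)

# qkv_fused_interleaved=True: heads are cycled as q_head_0, k_head_0, v_head_0, q_head_1, ...
qs, ks, vs = (torch.split(t, head_dim) for t in (q, k, v))
interleaved = torch.cat([t for group in zip(qs, ks, vs) for t in group])

# qkv_fused_interleaved=False: plain block concatenation [q; k; v].
blocked = torch.cat([q, k, v], dim=0)

print(interleaved.shape, blocked.shape)  # both torch.Size([24, 8])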
@@ -272,7 +272,7 @@ class TransformerBlock2D(nn.Module):

  """

- def __init__(self, config: unet_cfg.TransformerBlock2Dconfig):
+ def __init__(self, config: unet_cfg.TransformerBlock2DConfig):
  """Initialize an instance of the TransformerBlock2D.

  Args:
@@ -85,7 +85,7 @@ class FeedForwardBlock2DConfig:


  @dataclass
- class TransformerBlock2Dconfig:
+ class TransformerBlock2DConfig:
  pre_conv_normalization_config: layers_cfg.NormalizationConfig
  attention_block_config: AttentionBlock2DConfig
  cross_attention_block_config: CrossAttentionBlock2DConfig
@@ -108,7 +108,7 @@ class UpDecoderBlock2DConfig:
  # Optional sampling config if add_upsample is True.
  sampling_config: Optional[UpSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
- transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+ transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None

@@ -131,7 +131,7 @@ class SkipUpDecoderBlock2DConfig:
  # Optional sampling config if add_upsample is True.
  sampling_config: Optional[UpSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
- transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+ transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None

@@ -152,7 +152,7 @@ class DownEncoderBlock2DConfig:
  # Optional sampling config if add_upsample is True.
  sampling_config: Optional[DownSamplingConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
- transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+ transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None

@@ -168,7 +168,7 @@ class MidBlock2DConfig:
  # Optional config of attention blocks interleaved with residual blocks
  attention_block_config: Optional[AttentionBlock2DConfig] = None
  # Optional config of transformer blocks interleaved with residual blocks
- transformer_block_config: Optional[TransformerBlock2Dconfig] = None
+ transformer_block_config: Optional[TransformerBlock2DConfig] = None
  # Optional dimension of context tensor if context tensor is given as input.
  context_dim: Optional[int] = None

@@ -317,9 +317,12 @@ class ModelLoader:
  k: torch.Tensor,
  v: torch.Tensor,
  ) -> torch.Tensor:
- q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
- qs = torch.split(q, config.head_dim * q_per_kv)
- ks = torch.split(k, config.head_dim)
- vs = torch.split(v, config.head_dim)
- cycled = [t for group in zip(qs, ks, vs) for t in group]
- return torch.cat(cycled)
+ if config.attn_config.qkv_fused_interleaved:
+ q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+ qs = torch.split(q, config.head_dim * q_per_kv)
+ ks = torch.split(k, config.head_dim)
+ vs = torch.split(v, config.head_dim)
+ cycled = [t for group in zip(qs, ks, vs) for t in group]
+ return torch.cat(cycled)
+ else:
+ return torch.cat([q, k, v], dim=0)
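For the non-interleaved branch, the fused tensor is just q, k, and v stacked along the output dimension, so a consumer that knows the per-projection size can recover the three matrices with a plain three-way split. A sketch under the assumption of equally sized projections, as in the Stable Diffusion attention blocks:

import torch

fused = torch.randn(3 * 8, 8)           # [q; k; v] stacked block-wise, each 8 x 8
q, k, v = torch.split(fused, 8, dim=0)  # recover the individual projections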
@@ -37,6 +37,9 @@ class ResidualBlockTensorNames:
  class AttentionBlockTensorNames:
  norm: str = None
  fused_qkv_proj: str = None
+ q_proj: str = None
+ k_proj: str = None
+ v_proj: str = None
  output_proj: str = None

@@ -106,12 +109,21 @@ def _map_to_converted_state(
  state_param: str,
  converted_state: Dict[str, torch.Tensor],
  converted_state_param: str,
+ squeeze_dims: bool = False,
  ):
  converted_state[f"{converted_state_param}.weight"] = state.pop(
  f"{state_param}.weight"
  )
+ if squeeze_dims:
+ converted_state[f"{converted_state_param}.weight"] = torch.squeeze(
+ converted_state[f"{converted_state_param}.weight"]
+ )
  if f"{state_param}.bias" in state:
  converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
+ if squeeze_dims:
+ converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
+ converted_state[f"{converted_state_param}.bias"]
+ )


  class BaseLoader(loader.ModelLoader):
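The new squeeze_dims option looks aimed at checkpoints that store attention projections as 1x1 convolutions (for example the VAE decoder's attn_1 q/k/v tensors in a Stable Diffusion checkpoint): squeezing drops the trailing singleton kernel dimensions so the weight can be assigned to a linear projection. A sketch of the shape change, using an illustrative tensor rather than a real checkpoint entry:

import torch

conv_weight = torch.randn(512, 512, 1, 1)   # 1x1 Conv2d weight as stored in the checkpoint
linear_weight = torch.squeeze(conv_weight)  # shape becomes [512, 512], loadable into a linear layer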
@@ -179,17 +191,65 @@ class BaseLoader(loader.ModelLoader):
  f"{converted_state_param_prefix}.norm",
  )
  attention_layer_prefix = f"{converted_state_param_prefix}.attention"
- _map_to_converted_state(
- state,
- tensor_names.fused_qkv_proj,
- converted_state,
- f"{attention_layer_prefix}.qkv_projection",
- )
+ if tensor_names.fused_qkv_proj is not None:
+ _map_to_converted_state(
+ state,
+ tensor_names.fused_qkv_proj,
+ converted_state,
+ f"{attention_layer_prefix}.qkv_projection",
+ )
+ else:
+ _map_to_converted_state(
+ state,
+ tensor_names.q_proj,
+ converted_state,
+ f"{attention_layer_prefix}.q_projection",
+ squeeze_dims=True,
+ )
+ _map_to_converted_state(
+ state,
+ tensor_names.k_proj,
+ converted_state,
+ f"{attention_layer_prefix}.k_projection",
+ squeeze_dims=True,
+ )
+ _map_to_converted_state(
+ state,
+ tensor_names.v_proj,
+ converted_state,
+ f"{attention_layer_prefix}.v_projection",
+ squeeze_dims=True,
+ )
+ converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = torch.concat(
+ [
+ converted_state[f"{attention_layer_prefix}.q_projection.weight"],
+ converted_state[f"{attention_layer_prefix}.k_projection.weight"],
+ converted_state[f"{attention_layer_prefix}.v_projection.weight"],
+ ],
+ axis=0,
+ )
+ del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
+ del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
+ del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
+ if config.attention_config.qkv_use_bias:
+ converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = torch.concat(
+ [
+ converted_state[f"{attention_layer_prefix}.q_projection.bias"],
+ converted_state[f"{attention_layer_prefix}.k_projection.bias"],
+ converted_state[f"{attention_layer_prefix}.v_projection.bias"],
+ ],
+ axis=0,
+ )
+ del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
+ del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
+ del converted_state[f"{attention_layer_prefix}.v_projection.bias"]
+
  _map_to_converted_state(
  state,
  tensor_names.output_proj,
  converted_state,
  f"{attention_layer_prefix}.output_projection",
+ squeeze_dims=True,
  )

  def _map_cross_attention_block(
@@ -269,7 +329,7 @@ class BaseLoader(loader.ModelLoader):
  converted_state: Dict[str, torch.Tensor],
  tensor_names: TransformerBlockTensorNames,
  converted_state_param_prefix: str,
- config: unet_config.TransformerBlock2Dconfig,
+ config: unet_config.TransformerBlock2DConfig,
  ):
  _map_to_converted_state(
  state,
@@ -482,6 +542,10 @@ class BaseLoader(loader.ModelLoader):
  )


+ # Alias class name for better code reading.
+ ClipModelLoader = BaseLoader
+
+
  class AutoEncoderModelLoader(BaseLoader):

  @dataclass
@@ -668,7 +732,7 @@ class DiffusionModelLoader(BaseLoader):
  stride=2,
  padding=config.downsample_padding,
  ),
- transformer_block_config=unet_config.TransformerBlock2Dconfig(
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
  attention_block_config=unet_config.AttentionBlock2DConfig(
  dim=output_channel,
  normalization_config=config.transformer_norm_config,
@@ -726,7 +790,7 @@ class DiffusionModelLoader(BaseLoader):
  ),
  num_layers=config.mid_block_layers,
  time_embedding_channels=config.time_embedding_blocks_dim,
- transformer_block_config=unet_config.TransformerBlock2Dconfig(
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
  attention_block_config=unet_config.AttentionBlock2DConfig(
  dim=mid_block_channels,
  normalization_config=config.transformer_norm_config,
@@ -789,7 +853,7 @@ class DiffusionModelLoader(BaseLoader):
  mode=unet_config.SamplingType.NEAREST,
  scale_factor=2,
  ),
- transformer_block_config=unet_config.TransformerBlock2Dconfig(
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
  attention_block_config=unet_config.AttentionBlock2DConfig(
  dim=output_channel,
  normalization_config=config.transformer_norm_config,
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.2.0.dev20240710
+ Version: 0.2.0.dev20240711
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -43,12 +43,12 @@ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=uF1A2EX8xYie3
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=P-cUUQaQKGKV2p-7hvLJ--RpCIA7gk8WCDRgg0pNtd0,4331
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=XwV1z7cVkQ947k_ERftEeL8n0NUFCJAltLtqDVfzYGI,4704
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=xHcmOZaW7hoWlEEEqtB4FWoHMw5AsGHPHXMNiXEfviY,13814
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=G-MgiEM_PpegNMePBPuNQDeUfjk42EYrVZAyJHC54AY,28468
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
+ ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=AopJ-KE74lzq4QJUP_hYeiXvGth7uWv7nNKqkhtcoF8,8277
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
  ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=NFpOfA4KN0JpShm5QvuYbQYZ844NzexWD8nV3WjMOZM,2397
  ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -59,7 +59,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=7RwaZQaKhFt3zKAUbFjq95CSYhL1nd9BVSbSRNJp4-4,4529
  ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=KaGzCAViNOpJIQbRF-ItouuVPqI9nroWRRGN-KFYKZs,8357
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Sf3ZMYv-iuMRKAKLow47qth8vTF1zl6i8TxJ9uT_StU,3885
  ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
@@ -75,14 +75,14 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNl
  ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
- ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
+ ai_edge_torch/generative/layers/model_config.py,sha256=s6aIBib_LhjZC3p1pRxjcg3mf1BUrGqPQdsb6G83U-c,5028
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=T70veX57CC9uNidwzoVGzOu-CwzcYMBr1Zk_0bq5UlM,26538
  ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/quantize/example.py,sha256=Oy-Ss1oKXMu5RVOGt8QiUwKtrHEfhbVjTXXjxPcOqDA,1536
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -97,8 +97,8 @@ ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-y
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
  ai_edge_torch/generative/test/test_quantize.py,sha256=TxZwe2cCTfwq9t2thBuYiLdp5Xu2cspCbQgziZ3Oo7k,5269
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
+ ai_edge_torch/generative/utilities/loader.py,sha256=NTaCrU2qmeJpqdAau13ZgyeOpwATqhZB68GY0LZjU6A,11438
+ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=zixjZryUaCSDKmfPkQvYwbPJhUyTmZ4AK_lWN8iFo68,33324
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -114,8 +114,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/METADATA,sha256=6ask_HCsla1Tzx5_ORpPGrdvtwYAwS6BB3jNV31Jo9g,1745
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/METADATA,sha256=GftPz7zSGYCaTvO4gntWftMbj0NCSh4OXJEe1epdBCU,1745
+ ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/RECORD,,