ai-edge-torch-nightly 0.3.0.dev20240923__py3-none-any.whl → 0.3.0.dev20240925__py3-none-any.whl

Files changed (22)
  1. ai_edge_torch/generative/examples/openelm/openelm.py +1 -3
  2. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  3. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  4. ai_edge_torch/generative/examples/phi/verify.py +1 -1
  5. ai_edge_torch/generative/examples/phi/verify_phi3.py +68 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/clip.py +52 -1
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +56 -0
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +69 -1
  9. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  10. ai_edge_torch/generative/examples/test_models/toy_model.py +2 -31
  11. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +2 -56
  12. ai_edge_torch/generative/layers/builder.py +25 -24
  13. ai_edge_torch/generative/layers/model_config.py +3 -3
  14. ai_edge_torch/generative/layers/normalization.py +14 -3
  15. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  16. ai_edge_torch/generative/test/test_model_conversion_large.py +119 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA +2 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD +22 -18
  20. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
          ),
          ff_config=cfg.FeedForwardConfig(
              type=cfg.FeedForwardType.SEQUENTIAL,
-             activation=cfg.ActivationConfig(
-                 cfg.ActivationType.SILU_GLU, gate_is_front=True
-             ),
+             activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
              intermediate_size=get_intermediate_size(idx),
              pre_ff_norm_config=norm_config,
          ),
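Note: the dropped `gate_is_front=True` argument is absorbed by the reworked SILU_GLU activation (see the builder.py hunk below), which now always applies SiLU to the first half of the fused projection. A minimal standalone sketch of the equivalence, with an illustrative tensor size:

    import torch
    import torch.nn.functional as F
    from ai_edge_torch.generative.layers.builder import SwiGLU  # added in this diff

    x = torch.randn(2, 8)         # fused [gate | up] feed-forward projection
    gate, up = x.chunk(2, dim=-1)
    old = F.silu(gate) * up       # what build_glu(F.silu, gate_is_front=True) computed
    new = SwiGLU()(x)             # the replacement module defined in builder.py below
    assert torch.allclose(old, new)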
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting a Phi-3.5 model to a multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import converter
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+     'tflite_path',
+     '/tmp/',
+     'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
+     'prefill_seq_len',
+     1024,
+     'The maximum size of prefill input tensor.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+
+
+ def main(_):
+   pytorch_model = phi3.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+   converter.convert_to_tflite(
+       pytorch_model,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+       prefill_seq_len=_PREFILL_SEQ_LEN.value,
+       quantize=_QUANTIZE.value,
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
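Note: with the default flags, a run of this script writes phi3_q8_seq1024_ekv1280.tflite under /tmp/. A minimal sketch of the filename derivation, using the default flag values from above:

    quantize, prefill_seq_len, kv_cache_max_len = True, 1024, 1280
    quant_suffix = 'q8' if quantize else 'f32'
    print(f'phi3_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
    # -> phi3_q8_seq1024_ekv1280.tflite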
ai_edge_torch/generative/examples/phi/phi3.py ADDED
@@ -0,0 +1,286 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a Phi-3.5 model supporting up to 4K tokens (not the 128K variant)."""
+
+ import math
+ from typing import Tuple
+
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+     lm_head="lm_head",
+ )
+
+ # max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
+ ROPE_SCALE_FACTOR = 32
+
+ # RoPE short factors in Phi-3.5 config. According to the LongRoPE paper and
+ # its code in https://github.com/microsoft/LongRoPE, these values were
+ # searched with min=1.0 and step=0.01 to minimize the error on a sample
+ # dataset.
+ ROPE_SHORT_FACTOR = [
+     1.0,
+     1.0199999809265137,
+     1.0299999713897705,
+     1.0299999713897705,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0699999332427979,
+     1.0999999046325684,
+     1.1099998950958252,
+     1.1599998474121094,
+     1.1599998474121094,
+     1.1699998378753662,
+     1.2899998426437378,
+     1.339999794960022,
+     1.679999828338623,
+     1.7899998426437378,
+     1.8199998140335083,
+     1.8499997854232788,
+     1.8799997568130493,
+     1.9099997282028198,
+     1.9399996995925903,
+     1.9899996519088745,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0799996852874756,
+     2.0899996757507324,
+     2.189999580383301,
+     2.2199995517730713,
+     2.5899994373321533,
+     2.729999542236328,
+     2.749999523162842,
+     2.8399994373321533,
+ ]
+
+
+ def build_rope_cache(
+     size: int,
+     dim: int,
+     base: int = 10000,
+     condense_ratio: int = 1,
+     dtype: torch.dtype = torch.float32,
+     device: torch.device = None,
+     theta_factors: torch.Tensor = None,
+     scale: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+   """Precomputes Rotary Positional Embeddings for the Phi-3.5 model.
+
+   It's a modified version of attn_utils.build_rope_cache with additional
+   arguments for the Phi-3.5 model. It precomputes the Rotary Positional
+   Embedding sine and cosine values with scaling factors for quick lookup
+   during inference.
+
+   Args:
+     size (int): The size of the built cache.
+     dim (int): Each sequence's dimension.
+     base (int, optional): Rope base value. Defaults to 10000.
+     condense_ratio (int, optional): The ratio by which sequence indices are
+       condensed. Defaults to 1.
+     dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+       torch.float32.
+     device (torch.device, optional): Output tensor's device. Defaults to
+       None, in which case "cpu" is used.
+     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
+       scale the theta values. Defaults to None.
+     scale (float, optional): A float used to scale the rope values. Defaults
+       to 1.0.
+
+   Returns:
+     Tuple[torch.Tensor, torch.Tensor]: Rope's cosine and sine waves.
+   """
+   if device is None:
+     device = torch.device('cpu')
+   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+   if theta_factors is not None:
+     theta = theta / theta_factors
+   seq_idx = torch.arange(size) / condense_ratio
+   idx_theta = torch.outer(seq_idx, theta)
+   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+   return cos, sin
+
+
+ class Phi3_5Mini(nn.Module):
+   """A Phi-3.5 model built from the Edge Generative API layers."""
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     # Construct model layers.
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     # Phi-3.5 has only one block config.
+     block_config = config.block_config(0)
+     self.transformer_blocks = nn.ModuleList(
+         attention.TransformerBlock(block_config, config)
+         for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     attn_config = block_config.attn_config
+     self.rope_cache = build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+         theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+         scale=math.sqrt(
+             1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
+         ),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
+     assert self.config.max_seq_len >= seq_len, (
+         f"Cannot forward sequence of length {seq_len}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.kv_cache_max]
+
+     x = self.tok_embedding(tokens)
+
+     updated_kv_entries = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entries.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
+     x = self.final_norm(x)
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a Phi-3.5 model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache.
+       Defaults to 1024.
+
+   Returns:
+     The model config for a Phi-3.5 model.
+   """
+   attn_config = cfg.AttentionConfig(
+       num_heads=32,
+       head_dim=96,
+       num_query_groups=32,
+       rotary_percentage=1.0,
+       qkv_transpose_before_split=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+       intermediate_size=8192,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+   config = cfg.ModelConfig(
+       vocab_size=32064,
+       num_layers=32,
+       max_seq_len=4096,
+       kv_cache_max_len=kv_cache_max_len,
+       embedding_dim=3072,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+   config = get_model_config(kv_cache_max_len)
+   config.vocab_size = 128
+   config.num_layers = 2
+   config.max_seq_len = 2 * kv_cache_max_len
+   # Phi-3.5 has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 128
+   return config
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+   """Instantiates the model instance and loads the checkpoint if provided."""
+   config = get_model_config(**kwargs)
+   model = Phi3_5Mini(config)
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   loader.load(model)
+   model.eval()
+   return model
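Note: the RoPE scale passed to build_rope_cache above reduces to a constant, because 32 = 2**5 and 4096 = 2**12 make ln(32)/ln(4096) exactly 5/12. A standalone check of the arithmetic, using values from the configs above:

    import math
    import torch

    scale = math.sqrt(1 + math.log(32) / math.log(4096))
    assert abs(scale - math.sqrt(17 / 12)) < 1e-12   # ~1.1902

    # theta_factors has one entry per even rotary dimension: head_dim=96 gives
    # torch.arange(0, 96, 2) -> 48 frequencies, matching the 48 short factors.
    assert torch.arange(0, 96, 2).numel() == 48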
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -27,13 +27,13 @@ _PROMPTS = flags.DEFINE_multi_string(
      "Instruct: Write an email about the weather Output:",
      "The input prompts to generate answers.",
  )
-
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
      "max_new_tokens",
      30,
      "The maximum size of the generated tokens.",
  )

+
  def main(_):
    checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
    verifier.log_msg("Loading the original model from", checkpoint)
ai_edge_torch/generative/examples/phi/verify_phi3.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Phi-3.5 model."""
+
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "Instruct: Write an email about the weather Output:",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+
+ def main(_):
+   checkpoint = "microsoft/Phi-3.5-mini-instruct"
+   verifier.log_msg("Loading the original model from", checkpoint)
+   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+       hf_generation_config=generation_config,
+   )
+
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+   reauthored_model = phi3.build_model(reauthored_checkpoint)
+
+   verifier.log_msg("Loading the tokenizer from", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       generate_prompts=_PROMPTS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
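Note: like the other verify scripts in this directory, this one runs under absl.app, so a typical invocation looks like (the prompt shown is just the flag default):

    python verify_phi3.py --prompts "Instruct: Write an email about the weather Output:" --max_new_tokens 30

The first run downloads microsoft/Phi-3.5-mini-instruct through transformers; the reauthored model is then rebuilt from the same cached checkpoint, so both models share weights.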
ai_edge_torch/generative/examples/stable_diffusion/clip.py CHANGED
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(


  class CLIP(nn.Module):
-   """CLIP text encoder
+   """CLIP text encoder.

    For details, see https://arxiv.org/abs/2103.00020
    """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):


  def get_model_config() -> cfg.ModelConfig:
+   """Get configs for the CLIP of Stable Diffusion v1.5."""
    max_seq_len = 77
    vocab_size = 49408
    num_layers = 12
@@ -132,3 +133,53 @@ def get_model_config() -> cfg.ModelConfig:
  )

  return config
+
+
+ def get_fake_model_config() -> cfg.ModelConfig:
+   """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+   max_seq_len = 6
+   vocab_size = 100
+   num_layers = 2
+   num_heads = 12
+   num_query_groups = 12
+   embedding_dim = 24
+
+   attn_config = cfg.AttentionConfig(
+       num_heads=num_heads,
+       head_dim=embedding_dim // num_heads,
+       num_query_groups=num_query_groups,
+       rotary_percentage=0.0,
+       qkv_use_bias=True,
+       qkv_transpose_before_split=True,
+       qkv_fused_interleaved=False,
+       output_proj_use_bias=True,
+       enable_kv_cache=False,
+   )
+
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+       intermediate_size=embedding_dim * 4,
+       use_bias=True,
+   )
+
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+
+   config = cfg.ModelConfig(
+       vocab_size=vocab_size,
+       num_layers=num_layers,
+       max_seq_len=max_seq_len,
+       embedding_dim=embedding_dim,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+
+   return config
ai_edge_torch/generative/examples/stable_diffusion/decoder.py CHANGED
@@ -324,3 +324,59 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
      mid_block_config=mid_block_config,
  )
  return config
+
+
+ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+   """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+   in_channels = 3
+   latent_channels = 4
+   out_channels = 3
+   block_out_channels = [2, 4]
+   scaling_factor = 0.18215
+   layers_per_block = 2
+
+   norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+   )
+
+   att_config = unet_cfg.AttentionBlock2DConfig(
+       dim=block_out_channels[-1],
+       normalization_config=norm_config,
+       attention_config=layers_cfg.AttentionConfig(
+           num_heads=1,
+           head_dim=block_out_channels[-1],
+           num_query_groups=1,
+           qkv_use_bias=True,
+           output_proj_use_bias=True,
+           enable_kv_cache=False,
+           qkv_transpose_before_split=True,
+           qkv_fused_interleaved=False,
+           rotary_percentage=0.0,
+       ),
+       enable_hlfb=False,
+   )
+
+   mid_block_config = unet_cfg.MidBlock2DConfig(
+       in_channels=block_out_channels[-1],
+       normalization_config=norm_config,
+       activation_config=layers_cfg.ActivationConfig(
+           layers_cfg.ActivationType.SILU
+       ),
+       num_layers=1,
+       attention_block_config=att_config,
+   )
+
+   config = unet_cfg.AutoEncoderConfig(
+       in_channels=in_channels,
+       latent_channels=latent_channels,
+       out_channels=out_channels,
+       activation_config=layers_cfg.ActivationConfig(
+           layers_cfg.ActivationType.SILU
+       ),
+       block_out_channels=block_out_channels,
+       scaling_factor=scaling_factor,
+       layers_per_block=layers_per_block,
+       normalization_config=norm_config,
+       mid_block_config=mid_block_config,
+   )
+   return config
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py CHANGED
@@ -603,7 +603,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
  # Transformer configs.
  transformer_num_attention_heads = 8
  transformer_batch_size = batch_size
- transformer_cross_attention_dim = 768  # Embedding fomr CLIP model
+ transformer_cross_attention_dim = 768  # Embedding from CLIP model
  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
  )
@@ -645,3 +645,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
      final_norm_config=final_norm_config,
      final_activation_type=final_activation_type,
  )
+
+
+ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+   """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+   Args:
+     batch_size (int): the batch size of input.
+
+   Returns:
+     The configuration of the diffusion model of Stable Diffusion v1.5.
+   """
+   in_channels = 4
+   out_channels = 4
+   block_out_channels = [2, 4, 8, 8]
+   layers_per_block = 1
+   downsample_padding = 1
+
+   # Residual configs.
+   residual_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+   )
+   residual_activation_type = layers_cfg.ActivationType.SILU
+
+   # Transformer configs.
+   transformer_num_attention_heads = 1
+   transformer_batch_size = batch_size
+   transformer_cross_attention_dim = 4  # Embedding from CLIP model
+   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+   )
+   transformer_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.LAYER_NORM
+   )
+   transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+   # Time embedding configs.
+   time_embedding_dim = 2
+   time_embedding_blocks_dim = 4
+
+   # Mid block configs.
+   mid_block_layers = 1
+
+   # Final layer configs.
+   final_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+   )
+   final_activation_type = layers_cfg.ActivationType.SILU
+
+   return unet_cfg.DiffusionModelConfig(
+       in_channels=in_channels,
+       out_channels=out_channels,
+       block_out_channels=block_out_channels,
+       layers_per_block=layers_per_block,
+       downsample_padding=downsample_padding,
+       residual_norm_config=residual_norm_config,
+       residual_activation_type=residual_activation_type,
+       transformer_batch_size=transformer_batch_size,
+       transformer_num_attention_heads=transformer_num_attention_heads,
+       transformer_cross_attention_dim=transformer_cross_attention_dim,
+       transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+       transformer_norm_config=transformer_norm_config,
+       transformer_ff_activation_type=transformer_ff_activation_type,
+       mid_block_layers=mid_block_layers,
+       time_embedding_dim=time_embedding_dim,
+       time_embedding_blocks_dim=time_embedding_blocks_dim,
+       final_norm_config=final_norm_config,
+       final_activation_type=final_activation_type,
+   )
ai_edge_torch/generative/examples/test_models/convert_toy_model.py ADDED
@@ -0,0 +1,105 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a single-layer transformer block.
+ from absl import app
+ import ai_edge_torch
+ from ai_edge_torch import lowertools
+ from ai_edge_torch.generative.examples.test_models import toy_model
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import torch
+
+ KV_CACHE_MAX_LEN = 100
+
+
+ def convert_toy_model(_) -> None:
+   """Converts a toy model to tflite."""
+   model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+   print('running an inference')
+   print(
+       model.forward(
+           idx,
+           input_pos,
+       )
+   )
+
+   # Convert model to tflite.
+   print('converting model to tflite')
+   edge_model = ai_edge_torch.convert(
+       model,
+       (
+           idx,
+           input_pos,
+       ),
+   )
+   edge_model.export('/tmp/toy_model.tflite')
+
+
+ def _export_stablehlo_mlir(model, args):
+   ep = torch.export.export(model, args)
+   return lowertools.exported_program_to_mlir_text(ep)
+
+
+ def convert_toy_model_with_kv_cache(_) -> None:
+   """Converts a toy model with kv cache to tflite."""
+   dump_mlir = False
+
+   config = toy_model_with_kv_cache.get_model_config()
+   model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+   model.eval()
+   print('running an inference')
+   kv = kv_utils.KVCache.from_model_config(config)
+
+   tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+   decode_token, decode_input_pos = (
+       toy_model_with_kv_cache.get_sample_decode_inputs()
+   )
+   print(model.forward(tokens, input_pos, kv))
+
+   if dump_mlir:
+     mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+       f.write(mlir_text)
+
+   # Convert model to tflite with 2 signatures (prefill + decode).
+   print('converting toy model to tflite with 2 signatures (prefill + decode)')
+   edge_model = (
+       ai_edge_torch.signature(
+           'prefill',
+           model,
+           sample_kwargs={
+               'tokens': tokens,
+               'input_pos': input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .convert()
+   )
+   edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+ if __name__ == '__main__':
+   app.run(convert_toy_model)
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -15,13 +15,12 @@
  # A toy example which has a single-layer transformer block.
  from typing import Tuple

- import ai_edge_torch
+ from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
- import torch.nn as nn
+ from torch import nn

  RoPECache = Tuple[torch.Tensor, torch.Tensor]
  KV_CACHE_MAX_LEN = 100
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
      final_norm_config=norm_config,
  )
  return config
-
-
- def define_and_run() -> None:
-   model = ToySingleLayerModel(get_model_config())
-   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-   print('running an inference')
-   print(
-       model.forward(
-           idx,
-           input_pos,
-       )
-   )
-
-   # Convert model to tflite.
-   print('converting model to tflite')
-   edge_model = ai_edge_torch.convert(
-       model,
-       (
-           idx,
-           input_pos,
-       ),
-   )
-   edge_model.export('/tmp/toy_model.tflite')
-
-
- if __name__ == '__main__':
-   define_and_run()
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -17,15 +17,14 @@

  from typing import Tuple

- import ai_edge_torch
- from ai_edge_torch import lowertools
+ from absl import app
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
- import torch.nn as nn
+ from torch import nn

  RoPECache = Tuple[torch.Tensor, torch.Tensor]
@@ -87,11 +86,6 @@ class ToyModelWithKVCache(torch.nn.Module):
    return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}


- def _export_stablehlo_mlir(model, args):
-   ep = torch.export.export(model, args)
-   return lowertools.exported_program_to_mlir_text(ep)
-
-
  def get_model_config() -> cfg.ModelConfig:
    attn_config = cfg.AttentionConfig(
        num_heads=32,
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
  tokens = torch.tensor([[1]], dtype=torch.int)
  input_pos = torch.tensor([10])
  return tokens, input_pos
-
-
- def define_and_run() -> None:
-   dump_mlir = False
-
-   config = get_model_config()
-   model = ToyModelWithExternalKV(config)
-   model.eval()
-   print('running an inference')
-   kv = kv_utils.KVCache.from_model_config(config)
-
-   tokens, input_pos = get_sample_prefill_inputs()
-   decode_token, decode_input_pos = get_sample_decode_inputs()
-   print(model.forward(tokens, input_pos, kv))
-
-   if dump_mlir:
-     mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-       f.write(mlir_text)
-
-   # Convert model to tflite with 2 signatures (prefill + decode).
-   print('converting toy model to tflite with 2 signatures (prefill + decode)')
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill',
-           model,
-           sample_kwargs={
-               'tokens': tokens,
-               'input_pos': input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .signature(
-           'decode',
-           model,
-           sample_kwargs={
-               'tokens': decode_token,
-               'input_pos': decode_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .convert()
-   )
-   edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
- if __name__ == '__main__':
-   define_and_run()
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -23,34 +23,35 @@ from torch import nn
  import torch.nn.functional as F


- def build_glu(
-     act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
- ) -> Callable[[torch.Tensor], torch.Tensor]:
-   """Builds an activation function with GLU (Gated Linear Unit).
+ class GeGLU(nn.Module):
+   """GeGLU is an activation function which is a variant of GELU.

-   If gate_is_front is True,
-     f(x) = act(x) * y
-   otherwise,
-     f(x) = x * act(y),
-   where x is the first half of the input and y is the second half of the input.
+   GeGLU(x) = (xW+b) * GELU(xV+c)
+   See: https://arxiv.org/abs/2002.05202v1
+   """

-   Args:
-     act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
-       to the gate.
-     gate_is_front: whether the gate is in front half of the input. Other part is
-       the output in GLU.
+   def __init__(self, d_in: int, d_out: int):
+     super().__init__()
+     self.proj = nn.Linear(d_in, d_out * 2)

-   Returns:
-     A callable activation function with GLU.
+   def forward(self, x: torch.Tensor):
+     x, gate = self.proj(x).chunk(2, dim=-1)
+     return x * F.gelu(gate)
+
+
+ class SwiGLU(nn.Module):
+   """SwiGLU is an activation function which is a variant of GLU.
+
+   SwiGLU is the same as SILU_GLU, because the SiLU function is also known as
+   the swish function.
+
+   SwiGLU(x) = Swish(xW+b) * (xV+c)
+   See: https://paperswithcode.com/method/swiglu
    """

-   def _glu(x):
+   def forward(self, x: torch.Tensor):
      x, y = x.chunk(2, dim=-1)
-     if gate_is_front:
-       return act(x) * y
-     return x * act(y)
-
-   return _glu
+     return F.silu(x) * y


  def build_norm(dim: int, config: cfg.NormalizationConfig):
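Note: GeGLU owns its doubling projection, so it needs dim_in/dim_out at construction time, while SwiGLU is stateless and expects an input the feed-forward layer has already doubled. A standalone shape check with illustrative sizes:

    import torch
    from ai_edge_torch.generative.layers.builder import GeGLU, SwiGLU

    geglu = GeGLU(d_in=320, d_out=1280)   # proj maps 320 -> 2560 internally
    assert geglu(torch.randn(2, 77, 320)).shape == (2, 77, 1280)

    swiglu = SwiGLU()                     # input is already [x | y]
    assert swiglu(torch.randn(2, 16)).shape == (2, 8)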
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
    # See: https://github.com/hendrycks/GELUs
    return lambda x: x * F.sigmoid(1.702 * x)
  elif config.type == cfg.ActivationType.GE_GLU:
-   return build_glu(F.gelu, config.gate_is_front)
+   return GeGLU(config.dim_in, config.dim_out)
  elif config.type == cfg.ActivationType.RELU:
    return F.relu
  elif config.type == cfg.ActivationType.SILU_GLU:
-   return build_glu(F.silu, config.gate_is_front)
+   return SwiGLU()
  else:
    raise ValueError("Unsupported activation type.")
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -118,9 +118,9 @@ class AttentionConfig:
  @dataclass
  class ActivationConfig:
    type: ActivationType = ActivationType.LINEAR
-   # Whether to GLU gate is the front part instead of the back part of input
-   # when ActivationType is `GE_GLU` or `SILU_GLU`.
-   gate_is_front: bool = False
+   # Dimension of input and output, used in GeGLU.
+   dim_in: Optional[int] = None
+   dim_out: Optional[int] = None


  @dataclass
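Note: configured this way, GE_GLU now carries its projection dimensions while SILU_GLU needs none; a minimal sketch with illustrative sizes:

    geglu_act = ActivationConfig(ActivationType.GE_GLU, dim_in=320, dim_out=1280)
    silu_glu_act = ActivationConfig(ActivationType.SILU_GLU)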
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
  """
  x = torch.permute(x, (0, 2, 3, 1))

+ # TODO: b/366544750 - Change the "reduction_axes" field to an array, rather
+ # than an int32, when the bug is fixed.
  builder = StableHLOCompositeBuilder(
-     name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+     name="odml.group_norm",
+     attr={
+         "num_groups": num_groups,
+         "epsilon": eps,
+         "reduction_axes": 3,
+         "channel_axis": 3,
+     },
  )
  x, w, b = builder.mark_inputs(x, w, b)
  x = torch.permute(x, (0, 3, 1, 2))
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
  """Layer Normalization with high-level function boundary enabled.

  Args:
-   x (torch.Tensor): Input tensor for Layer Normalization.
+   x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
    w (torch.Tensor): The weight tensor for the normalization.
    b (torch.Tensor): The bias tensor for the normalization.
    eps (float): A small float value to ensure numerical stability.
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
  Returns:
    The output tensor of Layer Normalization.
  """
- builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+ builder = StableHLOCompositeBuilder(
+     name="odml.group_norm",
+     attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
+ )
  x, w, b = builder.mark_inputs(x, w, b)
  if use_input_shape:
    normalized_shape = x.shape
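Note: the rewrite expresses layer norm as the odml.group_norm composite with a single group, relying on the identity that group normalization with one group standardizes each sample over all of C, H, W, which is exactly layer normalization over those axes. A quick numeric check of that identity (without the affine parameters):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 4, 4)  # BCHW
    assert torch.allclose(
        F.group_norm(x, num_groups=1),
        F.layer_norm(x, normalized_shape=x.shape[1:]),
        atol=1e-6,
    )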
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================

- from typing import List, Optional, Tuple
+ from typing import List, Optional, Tuple, Union

  from ai_edge_torch.generative.layers.attention import CrossAttention
  from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
      output_hidden_states: bool = False,
- ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Forward function of the DownEncoderBlock2D.

    Args:
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -21,7 +21,11 @@ from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.examples.openelm import openelm
  from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.examples.phi import phi3
  from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
+ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
+ from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.test import utils as test_utils
  import numpy as np
@@ -109,6 +113,17 @@ class TestModelConversion(googletest.TestCase):
        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
    )

+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_phi3(self):
+     config = phi3.get_fake_model_config()
+     pytorch_model = phi3.Phi3_5Mini(config).eval()
+     self._test_model(
+         config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
+     )
+
    @googletest.skipIf(
        ai_edge_config.Config.use_torch_xla,
        reason="tests with custom ops are not supported on oss",
@@ -127,6 +142,110 @@ class TestModelConversion(googletest.TestCase):
    pytorch_model = openelm.OpenELM(config).eval()
    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_stable_diffusion_clip(self):
+     config = sd_clip.get_fake_model_config()
+     prompt_tokens = torch.from_numpy(
+         np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
+     )
+
+     pytorch_model = sd_clip.CLIP(config).eval()
+     torch_output = pytorch_model(prompt_tokens)
+
+     edge_model = ai_edge_torch.signature(
+         "encode", pytorch_model, (prompt_tokens,)
+     ).convert()
+     edge_model.set_interpreter_builder(
+         self._interpreter_builder(edge_model.tflite_model())
+     )
+     edge_output = edge_model(
+         prompt_tokens.numpy(),
+         signature_name="encode",
+     )
+     self.assertTrue(
+         np.allclose(
+             edge_output,
+             torch_output.detach().numpy(),
+             atol=1e-4,
+             rtol=1e-5,
+         )
+     )
+
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_stable_diffusion_diffusion(self):
+     config = sd_diffusion.get_fake_model_config(2)
+     latents = torch.from_numpy(
+         np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+     )
+     context = torch.from_numpy(
+         np.random.normal(size=(2, 4, 4)).astype(np.float32)
+     )
+     time_embedding = torch.from_numpy(
+         np.random.normal(size=(2, 2)).astype(np.float32)
+     )
+
+     pytorch_model = sd_diffusion.Diffusion(config).eval()
+     torch_output = pytorch_model(latents, context, time_embedding)
+
+     edge_model = ai_edge_torch.signature(
+         "diffusion", pytorch_model, (latents, context, time_embedding)
+     ).convert()
+     edge_model.set_interpreter_builder(
+         self._interpreter_builder(edge_model.tflite_model())
+     )
+     edge_output = edge_model(
+         latents.numpy(),
+         context.numpy(),
+         time_embedding.numpy(),
+         signature_name="diffusion",
+     )
+     self.assertTrue(
+         np.allclose(
+             edge_output,
+             torch_output.detach().numpy(),
+             atol=1e-4,
+             rtol=1e-5,
+         )
+     )
+
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_stable_diffusion_decoder(self):
+     config = sd_decoder.get_fake_model_config()
+     latents = torch.from_numpy(
+         np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+     )
+
+     pytorch_model = sd_decoder.Decoder(config).eval()
+     torch_output = pytorch_model(latents)
+
+     edge_model = ai_edge_torch.signature(
+         "decode", pytorch_model, (latents,)
+     ).convert()
+     edge_model.set_interpreter_builder(
+         self._interpreter_builder(edge_model.tflite_model())
+     )
+     edge_output = edge_model(
+         latents.numpy(),
+         signature_name="decode",
+     )
+     self.assertTrue(
+         np.allclose(
+             edge_output,
+             torch_output.detach().numpy(),
+             atol=1e-4,
+             rtol=1e-5,
+         )
+     )
+

  if __name__ == "__main__":
    googletest.main()
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20240923"
+ __version__ = "0.3.0.dev20240925"
{ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240923
+ Version: 0.3.0.dev20240925
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
  Requires-Dist: torch>=2.4.0
  Requires-Dist: torch-xla>=2.4.0
  Requires-Dist: tf-nightly>=2.18.0.dev20240722
+ Requires-Dist: ai-edge-litert-nightly
  Requires-Dist: ai-edge-quantizer-nightly

  Library that supports converting PyTorch models into a .tflite format, which can
{ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=oxtOOEY9LJkV5vRrgr1EoSjAjuetYVNq7WQqMuauRkc,706
+ ai_edge_torch/version.py,sha256=UXj1-90S3RDoHwYSmy9VdMC0Sm3EHt9ESLZbi3hnWus,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -48,22 +48,25 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
- ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=DIDzpG8DZkWDcWsAVkcxzxIC3U3352uVI3zMoYZD16U,9554
+ ai_edge_torch/generative/examples/phi/verify.py,sha256=5pQ0Bt8vGl8uTpkgXvOx8G7_rju0Gi8mIEr5NtRSAbs,2145
+ ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=o1UTqpimkeX3MDjgdG1QTQkoZHvCEnGClA0J0WB3wJ4,2328
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=2RMi5UmfMT4Ep68ZLJsqF-fMvEumNVkIwqtsRli9HhA,6068
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ZTRD56e8MsdGPJr7vpLa4Ju_BFw_b-FUgXgd-SO5MBw,15665
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6FAnevL8ZfCK2YCSPivarUH0Z8wGKSmnPpJNC0OI5A8,33680
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -78,8 +81,9 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F7
  ai_edge_torch/generative/examples/t5/t5.py,sha256=OZ67knK-UB1dBjxydG-Jwkp0Z3FzOCqGPTdg5aBFu4w,21328
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
+ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LTuzres5DHmrMT6U9rCrGf6vmR9SmopmB8sO6Cd2NxQ,5255
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xDYTh4m3vBEb6r3_ERhmj5qILW7YdVDAnZ-fitgYONg,4450
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
@@ -89,15 +93,15 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
- ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
+ ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
- ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
- ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
+ ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
+ ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -111,7 +115,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IzW2HjXS2-zePZM-qEuXL4zclnGvYsNw-6tuDSeNna4,8163
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -166,8 +170,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/METADATA,sha256=BgwLxDJ3AOPVn0fkngAQpf3YdmShufhMt3bANFevtiQ,1859
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/METADATA,sha256=5KsshdZ4-3X193HkoO2ukceyDEdWGvb8ZEMcw88qt7k,1897
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/RECORD,,