ai-edge-torch-nightly 0.3.0.dev20240923__py3-none-any.whl → 0.3.0.dev20240925__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. ai_edge_torch/generative/examples/openelm/openelm.py +1 -3
  2. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  3. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  4. ai_edge_torch/generative/examples/phi/verify.py +1 -1
  5. ai_edge_torch/generative/examples/phi/verify_phi3.py +68 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/clip.py +52 -1
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +56 -0
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +69 -1
  9. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  10. ai_edge_torch/generative/examples/test_models/toy_model.py +2 -31
  11. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +2 -56
  12. ai_edge_torch/generative/layers/builder.py +25 -24
  13. ai_edge_torch/generative/layers/model_config.py +3 -3
  14. ai_edge_torch/generative/layers/normalization.py +14 -3
  15. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  16. ai_edge_torch/generative/test/test_model_conversion_large.py +119 -0
  17. ai_edge_torch/version.py +1 -1
  18. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA +2 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD +22 -18
  20. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -161,9 +161,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
          ),
          ff_config=cfg.FeedForwardConfig(
              type=cfg.FeedForwardType.SEQUENTIAL,
-             activation=cfg.ActivationConfig(
-                 cfg.ActivationType.SILU_GLU, gate_is_front=True
-             ),
+             activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
              intermediate_size=get_intermediate_size(idx),
              pre_ff_norm_config=norm_config,
          ),
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of converting a Phi-3.5 model to a multi-signature tflite model."""
+
+ import os
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import converter
+
+ _CHECKPOINT_PATH = flags.DEFINE_string(
+     'checkpoint_path',
+     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+     'The path to the model checkpoint, or directory holding the checkpoint.',
+ )
+ _TFLITE_PATH = flags.DEFINE_string(
+     'tflite_path',
+     '/tmp/',
+     'The tflite file path to export.',
+ )
+ _PREFILL_SEQ_LEN = flags.DEFINE_integer(
+     'prefill_seq_len',
+     1024,
+     'The maximum size of prefill input tensor.',
+ )
+ _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+     'kv_cache_max_len',
+     1280,
+     'The maximum size of KV cache buffer, including both prefill and decode.',
+ )
+ _QUANTIZE = flags.DEFINE_bool(
+     'quantize',
+     True,
+     'Whether the model should be quantized.',
+ )
+
+
+ def main(_):
+   pytorch_model = phi3.build_model(
+       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+   )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+   converter.convert_to_tflite(
+       pytorch_model,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+       prefill_seq_len=_PREFILL_SEQ_LEN.value,
+       quantize=_QUANTIZE.value,
+   )
+
+
+ if __name__ == '__main__':
+   app.run(main)
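
Note: the flags above map one-to-one onto a programmatic call, which can be handy from a notebook or another script. A minimal sketch using only the APIs shown in this diff (the checkpoint path is illustrative, not a path the package ships):

    from ai_edge_torch.generative.examples.phi import phi3
    from ai_edge_torch.generative.utilities import converter

    # Illustrative checkpoint location; any directory holding a Phi-3.5
    # checkpoint that phi3.build_model understands works here.
    pytorch_model = phi3.build_model('/tmp/llm_data/phi3', kv_cache_max_len=1280)
    converter.convert_to_tflite(
        pytorch_model,
        tflite_path='/tmp/phi3_q8_seq1024_ekv1280.tflite',
        prefill_seq_len=1024,
        quantize=True,
    )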
ai_edge_torch/generative/examples/phi/phi3.py ADDED
@@ -0,0 +1,286 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Example of building a Phi-3.5 model, which supports up to 4K tokens (not 128K)."""
+
+ import math
+ from typing import Tuple
+
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ import ai_edge_torch.generative.layers.model_config as cfg
+ import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn
+
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+     ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+     ff_down_proj="model.layers.{}.mlp.down_proj",
+     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+     attn_output_proj="model.layers.{}.self_attn.o_proj",
+     pre_attn_norm="model.layers.{}.input_layernorm",
+     post_attn_norm="model.layers.{}.post_attention_layernorm",
+     embedding="model.embed_tokens",
+     final_norm="model.norm",
+     lm_head="lm_head",
+ )
+
+ # max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
+ ROPE_SCALE_FACTOR = 32
+
+ # RoPE short factor in the Phi-3.5 config. According to the LongRoPE paper and
+ # its code in https://github.com/microsoft/LongRoPE, these values were searched
+ # with min=1.0, step=0.01 to minimize the errors on a sample dataset.
+ ROPE_SHORT_FACTOR = [
+     1.0,
+     1.0199999809265137,
+     1.0299999713897705,
+     1.0299999713897705,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0499999523162842,
+     1.0699999332427979,
+     1.0999999046325684,
+     1.1099998950958252,
+     1.1599998474121094,
+     1.1599998474121094,
+     1.1699998378753662,
+     1.2899998426437378,
+     1.339999794960022,
+     1.679999828338623,
+     1.7899998426437378,
+     1.8199998140335083,
+     1.8499997854232788,
+     1.8799997568130493,
+     1.9099997282028198,
+     1.9399996995925903,
+     1.9899996519088745,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0199997425079346,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0299997329711914,
+     2.0799996852874756,
+     2.0899996757507324,
+     2.189999580383301,
+     2.2199995517730713,
+     2.5899994373321533,
+     2.729999542236328,
+     2.749999523162842,
+     2.8399994373321533,
+ ]
+
+
+ def build_rope_cache(
+     size: int,
+     dim: int,
+     base: int = 10000,
+     condense_ratio: int = 1,
+     dtype: torch.dtype = torch.float32,
+     device: torch.device = None,
+     theta_factors: torch.Tensor = None,
+     scale: float = 1.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+   """Precomputes Rotary Positional Embeddings for the Phi-3.5 model.
+
+   It's a modified version of attn_utils.build_rope_cache with additional
+   arguments for the Phi-3.5 model. It precomputes the Rotary Positional
+   Embedding sin and cos values with scaling factors for quick lookup during
+   inference.
+
+   Args:
+     size (int): The size of the built cache.
+     dim (int): Each sequence's dimension.
+     base (int, optional): Rope base value. Defaults to 10000.
+     condense_ratio (int, optional): The ratio by which sequence indices are
+       condensed. Defaults to 1.
+     dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+       torch.float32.
+     device (torch.device, optional): Output tensor's device. Defaults to
+       None, in which case "cpu" is used.
+     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
+       scale the theta values. Defaults to None.
+     scale (float, optional): A float used to scale the rope values. Defaults
+       to 1.0.
+
+   Returns:
+     Tuple[torch.Tensor, torch.Tensor]: Rope's cosine and sine waves.
+   """
+   if device is None:
+     device = torch.device('cpu')
+   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+   if theta_factors is not None:
+     theta = theta / theta_factors
+   seq_idx = torch.arange(size) / condense_ratio
+   idx_theta = torch.outer(seq_idx, theta)
+   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+   return cos, sin
+
+
+ class Phi3_5Mini(nn.Module):
+   """A Phi-3.5 model built from the Edge Generative API layers."""
+
+   def __init__(self, config: cfg.ModelConfig):
+     super().__init__()
+
+     # Construct model layers.
+     self.lm_head = nn.Linear(
+         config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+     )
+     self.tok_embedding = nn.Embedding(
+         config.vocab_size, config.embedding_dim, padding_idx=0
+     )
+     # Phi-3.5 has only one block config.
+     block_config = config.block_config(0)
+     self.transformer_blocks = nn.ModuleList(
+         attention.TransformerBlock(block_config, config)
+         for _ in range(config.num_layers)
+     )
+     self.final_norm = builder.build_norm(
+         config.embedding_dim,
+         config.final_norm_config,
+     )
+     attn_config = block_config.attn_config
+     self.rope_cache = build_rope_cache(
+         size=config.kv_cache_max,
+         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+         base=10_000,
+         condense_ratio=1,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+         theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+         scale=math.sqrt(
+             1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
+         ),
+     )
+     self.mask_cache = attn_utils.build_causal_mask_cache(
+         size=config.kv_cache_max,
+         dtype=torch.float32,
+         device=torch.device("cpu"),
+     )
+     self.config = config
+
+   @torch.inference_mode
+   def forward(
+       self,
+       tokens: torch.Tensor,
+       input_pos: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> dict[torch.Tensor, kv_utils.KVCache]:
+     _, seq_len = tokens.size()
+     assert self.config.max_seq_len >= seq_len, (
+         f"Cannot forward sequence of length {seq_len}, max seq length is only"
+         f" {self.config.max_seq_len}"
+     )
+     assert len(self.transformer_blocks) == len(kv_cache.caches), (
+         "The number of transformer blocks and the number of KV cache entries"
+         " must be the same."
+     )
+
+     cos, sin = self.rope_cache
+     cos = cos.index_select(0, input_pos)
+     sin = sin.index_select(0, input_pos)
+     mask = self.mask_cache.index_select(2, input_pos)
+     mask = mask[:, :, :, : self.config.kv_cache_max]
+
+     x = self.tok_embedding(tokens)
+
+     updated_kv_entries = []
+     for i, block in enumerate(self.transformer_blocks):
+       kv_entry = kv_cache.caches[i] if kv_cache else None
+       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+       if kv_entry:
+         updated_kv_entries.append(kv_entry)
+     updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
+     x = self.final_norm(x)
+     logits = self.lm_head(x)  # (b, t, vocab_size)
+     return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+   """Returns the model config for a Phi-3.5 model.
+
+   Args:
+     kv_cache_max_len (int): The maximum sequence length of the KV cache.
+       Default is 1024.
+
+   Returns:
+     The model config for a Phi-3.5 model.
+   """
+   attn_config = cfg.AttentionConfig(
+       num_heads=32,
+       head_dim=96,
+       num_query_groups=32,
+       rotary_percentage=1.0,
+       qkv_transpose_before_split=True,
+   )
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+       intermediate_size=8192,
+   )
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+   config = cfg.ModelConfig(
+       vocab_size=32064,
+       num_layers=32,
+       max_seq_len=4096,
+       kv_cache_max_len=kv_cache_max_len,
+       embedding_dim=3072,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+   return config
+
+
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+   config = get_model_config(kv_cache_max_len)
+   config.vocab_size = 128
+   config.num_layers = 2
+   config.max_seq_len = 2 * kv_cache_max_len
+   # Phi-3.5 has only one block config.
+   config.block_config(0).ff_config.intermediate_size = 128
+   return config
+
+
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+   """Instantiates the model instance and loads the checkpoint if provided."""
+   config = get_model_config(**kwargs)
+   model = Phi3_5Mini(config)
+   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+   loader.load(model)
+   model.eval()
+   return model
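
Note: the `scale` passed to `build_rope_cache` is the LongRoPE attention-scaling term, sqrt(1 + log(scale_factor) / log(original_context_length)). With the constants above it is a fixed number that can be checked by hand (a standalone sketch, not part of the package):

    import math

    ROPE_SCALE_FACTOR = 32  # max_position_embeddings / original_max_position_embeddings
    max_seq_len = 4096      # from get_model_config()

    scale = math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len))
    print(round(scale, 4))  # 1.1902 -- log(32)/log(4096) = 5/12, sqrt(1 + 5/12)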
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -27,13 +27,13 @@ _PROMPTS = flags.DEFINE_multi_string(
      "Instruct: Write an email about the weather Output:",
      "The input prompts to generate answers.",
  )
-
  _MAX_NEW_TOKENS = flags.DEFINE_integer(
      "max_new_tokens",
      30,
      "The maximum size of the generated tokens.",
  )
 
+
  def main(_):
    checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
    verifier.log_msg("Loading the original model from", checkpoint)
ai_edge_torch/generative/examples/phi/verify_phi3.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Phi-3.5 model."""
+
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "Instruct: Write an email about the weather Output:",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+
+ def main(_):
+   checkpoint = "microsoft/Phi-3.5-mini-instruct"
+   verifier.log_msg("Loading the original model from", checkpoint)
+   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+       hf_generation_config=generation_config,
+   )
+
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+   reauthored_model = phi3.build_model(reauthored_checkpoint)
+
+   verifier.log_msg("Loading the tokenizer from", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       generate_prompts=_PROMPTS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
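
Note: the "locate the cached dir" trick reuses the Hugging Face download cache rather than downloading the checkpoint twice. It also works standalone; a small sketch assuming `transformers` is installed and the model has already been fetched:

    import pathlib
    import transformers

    checkpoint = "microsoft/Phi-3.5-mini-instruct"
    # cached_file returns the local path of config.json inside the HF cache;
    # its parent directory holds the checkpoint weight files.
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    print(pathlib.Path(cached_config_file).parent)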
ai_edge_torch/generative/examples/stable_diffusion/clip.py CHANGED
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class CLIP(nn.Module):
-   """CLIP text encoder
+   """CLIP text encoder.
 
    For details, see https://arxiv.org/abs/2103.00020
    """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
 
 
  def get_model_config() -> cfg.ModelConfig:
+   """Get configs for the CLIP of Stable Diffusion v1.5."""
    max_seq_len = 77
    vocab_size = 49408
    num_layers = 12
@@ -132,3 +133,53 @@ def get_model_config() -> cfg.ModelConfig:
    )
 
    return config
+
+
+ def get_fake_model_config() -> cfg.ModelConfig:
+   """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+   max_seq_len = 6
+   vocab_size = 100
+   num_layers = 2
+   num_heads = 12
+   num_query_groups = 12
+   embedding_dim = 24
+
+   attn_config = cfg.AttentionConfig(
+       num_heads=num_heads,
+       head_dim=embedding_dim // num_heads,
+       num_query_groups=num_query_groups,
+       rotary_percentage=0.0,
+       qkv_use_bias=True,
+       qkv_transpose_before_split=True,
+       qkv_fused_interleaved=False,
+       output_proj_use_bias=True,
+       enable_kv_cache=False,
+   )
+
+   ff_config = cfg.FeedForwardConfig(
+       type=cfg.FeedForwardType.SEQUENTIAL,
+       activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+       intermediate_size=embedding_dim * 4,
+       use_bias=True,
+   )
+
+   norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+   block_config = cfg.TransformerBlockConfig(
+       attn_config=attn_config,
+       ff_config=ff_config,
+       pre_attention_norm_config=norm_config,
+       post_attention_norm_config=norm_config,
+   )
+
+   config = cfg.ModelConfig(
+       vocab_size=vocab_size,
+       num_layers=num_layers,
+       max_seq_len=max_seq_len,
+       embedding_dim=embedding_dim,
+       block_configs=block_config,
+       final_norm_config=norm_config,
+       enable_hlfb=True,
+   )
+
+   return config
ai_edge_torch/generative/examples/stable_diffusion/decoder.py CHANGED
@@ -324,3 +324,59 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
      mid_block_config=mid_block_config,
  )
  return config
+
+
+ def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+   """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+   in_channels = 3
+   latent_channels = 4
+   out_channels = 3
+   block_out_channels = [2, 4]
+   scaling_factor = 0.18215
+   layers_per_block = 2
+
+   norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+   )
+
+   att_config = unet_cfg.AttentionBlock2DConfig(
+       dim=block_out_channels[-1],
+       normalization_config=norm_config,
+       attention_config=layers_cfg.AttentionConfig(
+           num_heads=1,
+           head_dim=block_out_channels[-1],
+           num_query_groups=1,
+           qkv_use_bias=True,
+           output_proj_use_bias=True,
+           enable_kv_cache=False,
+           qkv_transpose_before_split=True,
+           qkv_fused_interleaved=False,
+           rotary_percentage=0.0,
+       ),
+       enable_hlfb=False,
+   )
+
+   mid_block_config = unet_cfg.MidBlock2DConfig(
+       in_channels=block_out_channels[-1],
+       normalization_config=norm_config,
+       activation_config=layers_cfg.ActivationConfig(
+           layers_cfg.ActivationType.SILU
+       ),
+       num_layers=1,
+       attention_block_config=att_config,
+   )
+
+   config = unet_cfg.AutoEncoderConfig(
+       in_channels=in_channels,
+       latent_channels=latent_channels,
+       out_channels=out_channels,
+       activation_config=layers_cfg.ActivationConfig(
+           layers_cfg.ActivationType.SILU
+       ),
+       block_out_channels=block_out_channels,
+       scaling_factor=scaling_factor,
+       layers_per_block=layers_per_block,
+       normalization_config=norm_config,
+       mid_block_config=mid_block_config,
+   )
+   return config
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py CHANGED
@@ -603,7 +603,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
    # Transformer configs.
    transformer_num_attention_heads = 8
    transformer_batch_size = batch_size
-   transformer_cross_attention_dim = 768  # Embedding fomr CLIP model
+   transformer_cross_attention_dim = 768  # Embedding from CLIP model
    transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
        layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
    )
@@ -645,3 +645,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
      final_norm_config=final_norm_config,
      final_activation_type=final_activation_type,
  )
+
+
+ def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+   """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+   Args:
+     batch_size (int): the batch size of input.
+
+   Returns:
+     The configuration of the diffusion model of Stable Diffusion v1.5.
+   """
+   in_channels = 4
+   out_channels = 4
+   block_out_channels = [2, 4, 8, 8]
+   layers_per_block = 1
+   downsample_padding = 1
+
+   # Residual configs.
+   residual_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+   )
+   residual_activation_type = layers_cfg.ActivationType.SILU
+
+   # Transformer configs.
+   transformer_num_attention_heads = 1
+   transformer_batch_size = batch_size
+   transformer_cross_attention_dim = 4  # Embedding from CLIP model
+   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+   )
+   transformer_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.LAYER_NORM
+   )
+   transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+   # Time embedding configs.
+   time_embedding_dim = 2
+   time_embedding_blocks_dim = 4
+
+   # Mid block configs.
+   mid_block_layers = 1
+
+   # Final layer configs.
+   final_norm_config = layers_cfg.NormalizationConfig(
+       layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+   )
+   final_activation_type = layers_cfg.ActivationType.SILU
+
+   return unet_cfg.DiffusionModelConfig(
+       in_channels=in_channels,
+       out_channels=out_channels,
+       block_out_channels=block_out_channels,
+       layers_per_block=layers_per_block,
+       downsample_padding=downsample_padding,
+       residual_norm_config=residual_norm_config,
+       residual_activation_type=residual_activation_type,
+       transformer_batch_size=transformer_batch_size,
+       transformer_num_attention_heads=transformer_num_attention_heads,
+       transformer_cross_attention_dim=transformer_cross_attention_dim,
+       transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+       transformer_norm_config=transformer_norm_config,
+       transformer_ff_activation_type=transformer_ff_activation_type,
+       mid_block_layers=mid_block_layers,
+       time_embedding_dim=time_embedding_dim,
+       time_embedding_blocks_dim=time_embedding_blocks_dim,
+       final_norm_config=final_norm_config,
+       final_activation_type=final_activation_type,
+   )
ai_edge_torch/generative/examples/test_models/convert_toy_model.py ADDED
@@ -0,0 +1,105 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # A toy example which has a single-layer transformer block.
+ from absl import app
+ import ai_edge_torch
+ from ai_edge_torch import lowertools
+ from ai_edge_torch.generative.examples.test_models import toy_model
+ from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import torch
+
+ KV_CACHE_MAX_LEN = 100
+
+
+ def convert_toy_model(_) -> None:
+   """Converts a toy model to tflite."""
+   model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+   print('running an inference')
+   print(
+       model.forward(
+           idx,
+           input_pos,
+       )
+   )
+
+   # Convert model to tflite.
+   print('converting model to tflite')
+   edge_model = ai_edge_torch.convert(
+       model,
+       (
+           idx,
+           input_pos,
+       ),
+   )
+   edge_model.export('/tmp/toy_model.tflite')
+
+
+ def _export_stablehlo_mlir(model, args):
+   ep = torch.export.export(model, args)
+   return lowertools.exported_program_to_mlir_text(ep)
+
+
+ def convert_toy_model_with_kv_cache(_) -> None:
+   """Converts a toy model with kv cache to tflite."""
+   dump_mlir = False
+
+   config = toy_model_with_kv_cache.get_model_config()
+   model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+   model.eval()
+   print('running an inference')
+   kv = kv_utils.KVCache.from_model_config(config)
+
+   tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+   decode_token, decode_input_pos = (
+       toy_model_with_kv_cache.get_sample_decode_inputs()
+   )
+   print(model.forward(tokens, input_pos, kv))
+
+   if dump_mlir:
+     mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+       f.write(mlir_text)
+
+   # Convert model to tflite with 2 signatures (prefill + decode).
+   print('converting toy model to tflite with 2 signatures (prefill + decode)')
+   edge_model = (
+       ai_edge_torch.signature(
+           'prefill',
+           model,
+           sample_kwargs={
+               'tokens': tokens,
+               'input_pos': input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .signature(
+           'decode',
+           model,
+           sample_kwargs={
+               'tokens': decode_token,
+               'input_pos': decode_input_pos,
+               'kv_cache': kv,
+           },
+       )
+       .convert()
+   )
+   edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+ if __name__ == '__main__':
+   app.run(convert_toy_model)
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -15,13 +15,12 @@
  # A toy example which has a single-layer transformer block.
  from typing import Tuple
 
- import ai_edge_torch
+ from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
- import torch.nn as nn
+ from torch import nn
 
  RoPECache = Tuple[torch.Tensor, torch.Tensor]
  KV_CACHE_MAX_LEN = 100
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
      final_norm_config=norm_config,
  )
  return config
-
-
- def define_and_run() -> None:
-   model = ToySingleLayerModel(get_model_config())
-   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-   print('running an inference')
-   print(
-       model.forward(
-           idx,
-           input_pos,
-       )
-   )
-
-   # Convert model to tflite.
-   print('converting model to tflite')
-   edge_model = ai_edge_torch.convert(
-       model,
-       (
-           idx,
-           input_pos,
-       ),
-   )
-   edge_model.export('/tmp/toy_model.tflite')
-
-
- if __name__ == '__main__':
-   define_and_run()
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -17,15 +17,14 @@
 
  from typing import Tuple
 
- import ai_edge_torch
- from ai_edge_torch import lowertools
+ from absl import app
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import torch
- import torch.nn as nn
+ from torch import nn
 
  RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -87,11 +86,6 @@ class ToyModelWithKVCache(torch.nn.Module):
      return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
- def _export_stablehlo_mlir(model, args):
-   ep = torch.export.export(model, args)
-   return lowertools.exported_program_to_mlir_text(ep)
-
-
  def get_model_config() -> cfg.ModelConfig:
    attn_config = cfg.AttentionConfig(
        num_heads=32,
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
    tokens = torch.tensor([[1]], dtype=torch.int)
    input_pos = torch.tensor([10])
    return tokens, input_pos
-
-
- def define_and_run() -> None:
-   dump_mlir = False
-
-   config = get_model_config()
-   model = ToyModelWithExternalKV(config)
-   model.eval()
-   print('running an inference')
-   kv = kv_utils.KVCache.from_model_config(config)
-
-   tokens, input_pos = get_sample_prefill_inputs()
-   decode_token, decode_input_pos = get_sample_decode_inputs()
-   print(model.forward(tokens, input_pos, kv))
-
-   if dump_mlir:
-     mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-     with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-       f.write(mlir_text)
-
-   # Convert model to tflite with 2 signatures (prefill + decode).
-   print('converting toy model to tflite with 2 signatures (prefill + decode)')
-   edge_model = (
-       ai_edge_torch.signature(
-           'prefill',
-           model,
-           sample_kwargs={
-               'tokens': tokens,
-               'input_pos': input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .signature(
-           'decode',
-           model,
-           sample_kwargs={
-               'tokens': decode_token,
-               'input_pos': decode_input_pos,
-               'kv_cache': kv,
-           },
-       )
-       .convert()
-   )
-   edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
- if __name__ == '__main__':
-   define_and_run()
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -23,34 +23,35 @@ from torch import nn
  import torch.nn.functional as F
 
 
- def build_glu(
-     act: Callable[[torch.Tensor], torch.Tensor], gate_is_front: bool = False
- ) -> Callable[[torch.Tensor], torch.Tensor]:
-   """Builds an activation function with GLU (Gated Linear Unit).
+ class GeGLU(nn.Module):
+   """GeGLU is an activation function which is a variant of GELU.
 
-   If gate_is_front is True,
-     f(x) = act(x) * y
-   otherwise,
-     f(x) = x * act(y),
-   where x is the first half of the input and y is the second half of the input.
+   GeGLU(x) = (xW+b) * GELU(xV+c)
+   See: https://arxiv.org/abs/2002.05202v1
+   """
 
-   Args:
-     act (Callable[[torch.Tensor], torch.Tensor]): activation function to apply
-       to the gate.
-     gate_is_front: whether the gate is in front half of the input. Other part
-       is the output in GLU.
+   def __init__(self, d_in: int, d_out: int):
+     super().__init__()
+     self.proj = nn.Linear(d_in, d_out * 2)
 
-   Returns:
-     A callable activation function with GLU.
+   def forward(self, x: torch.Tensor):
+     x, gate = self.proj(x).chunk(2, dim=-1)
+     return x * F.gelu(gate)
+
+
+ class SwiGLU(nn.Module):
+   """SwiGLU is an activation function which is a variant of GLU.
+
+   SwiGLU is the same as SiLU_GLU because the SiLU function is also known as
+   the swish function.
+
+   SwiGLU(x) = Swish(xW+b) * (xV+c)
+   See: https://paperswithcode.com/method/swiglu
    """
 
-   def _glu(x):
+   def forward(self, x: torch.Tensor):
      x, y = x.chunk(2, dim=-1)
-     if gate_is_front:
-       return act(x) * y
-     return x * act(y)
-
-   return _glu
+     return F.silu(x) * y
 
 
  def build_norm(dim: int, config: cfg.NormalizationConfig):
@@ -151,10 +152,10 @@ def get_activation(config: cfg.ActivationConfig):
    # See: https://github.com/hendrycks/GELUs
    return lambda x: x * F.sigmoid(1.702 * x)
  elif config.type == cfg.ActivationType.GE_GLU:
-   return build_glu(F.gelu, config.gate_is_front)
+   return GeGLU(config.dim_in, config.dim_out)
  elif config.type == cfg.ActivationType.RELU:
    return F.relu
  elif config.type == cfg.ActivationType.SILU_GLU:
-   return build_glu(F.silu, config.gate_is_front)
+   return SwiGLU()
  else:
    raise ValueError("Unsupported activation type.")
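
Note: as a quick illustration of the chunk-based gating (a standalone sketch mirroring the SwiGLU class added above, not part of the package), SwiGLU splits the last dimension in half and gates one half with SiLU of the other, so the output has half the input's last dimension:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class SwiGLU(nn.Module):
      def forward(self, x: torch.Tensor):
        x, y = x.chunk(2, dim=-1)   # front half is the gate
        return F.silu(x) * y

    x = torch.randn(1, 4, 16)       # last dim must be even
    out = SwiGLU()(x)
    print(out.shape)                # torch.Size([1, 4, 8])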
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -118,9 +118,9 @@ class AttentionConfig:
  @dataclass
  class ActivationConfig:
    type: ActivationType = ActivationType.LINEAR
-   # Whether the GLU gate is the front part instead of the back part of input
-   # when ActivationType is `GE_GLU` or `SILU_GLU`.
-   gate_is_front: bool = False
+   # Dimension of input and output, used in GeGLU.
+   dim_in: Optional[int] = None
+   dim_out: Optional[int] = None
 
 
  @dataclass
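
Note: under the new config, `GE_GLU` carries the projection dimensions GeGLU needs, while `SILU_GLU` takes no extra fields because SwiGLU fixes the gate to the front half of the input. A small sketch (the dim values are purely illustrative):

    import ai_edge_torch.generative.layers.model_config as cfg

    # SILU_GLU no longer takes gate_is_front.
    silu_glu = cfg.ActivationConfig(cfg.ActivationType.SILU_GLU)

    # GE_GLU now carries the nn.Linear projection dims used to build GeGLU;
    # 320/1280 are made-up example values.
    ge_glu = cfg.ActivationConfig(
        cfg.ActivationType.GE_GLU, dim_in=320, dim_out=1280
    )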
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -183,8 +183,16 @@ def group_norm_with_hlfb(
    """
    x = torch.permute(x, (0, 2, 3, 1))
 
+   # TODO: b/366544750 - Change the "reduction_axes" field to an array, rather
+   # than int32, when the bug is fixed.
    builder = StableHLOCompositeBuilder(
-       name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+       name="odml.group_norm",
+       attr={
+           "num_groups": num_groups,
+           "epsilon": eps,
+           "reduction_axes": 3,
+           "channel_axis": 3,
+       },
    )
    x, w, b = builder.mark_inputs(x, w, b)
    x = torch.permute(x, (0, 3, 1, 2))
@@ -206,7 +214,7 @@ def layer_norm_with_hlfb(
    """Layer Normalization with high-level function boundary enabled.
 
    Args:
-     x (torch.Tensor): Input tensor for Layer Normalization.
+     x (torch.Tensor): Input tensor for Layer Normalization, with BCHW shape.
      w (torch.Tensor): The weight tensor for the normalization.
      b (torch.Tensor): The bias tensor for the normalization.
      eps (float): A small float value to ensure numerical stability.
@@ -216,7 +224,10 @@ def layer_norm_with_hlfb(
    Returns:
      The output tensor of Layer Normalization.
    """
-   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+   builder = StableHLOCompositeBuilder(
+       name="odml.group_norm",
+       attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
+   )
    x, w, b = builder.mark_inputs(x, w, b)
    if use_input_shape:
      normalized_shape = x.shape
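
Note: marking layer norm as an `odml.group_norm` composite with `num_groups=1` works because LayerNorm over a full (C, H, W) feature map is exactly the one-group case of GroupNorm. A quick check with plain PyTorch (standalone, not part of the package):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 8, 4, 4)               # BCHW
    ln = F.layer_norm(x, x.shape[1:])         # normalize over (C, H, W)
    gn = F.group_norm(x, num_groups=1)        # one group == whole feature map
    print(torch.allclose(ln, gn, atol=1e-6))  # True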
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================
 
- from typing import List, Optional, Tuple
+ from typing import List, Optional, Tuple, Union
 
  from ai_edge_torch.generative.layers.attention import CrossAttention
  from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
      time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
      output_hidden_states: bool = False,
- ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Forward function of the DownEncoderBlock2D.
 
    Args:
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -21,7 +21,11 @@ from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.examples.openelm import openelm
  from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.examples.phi import phi3
  from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
+ from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
+ from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
  from ai_edge_torch.generative.layers import kv_cache
  from ai_edge_torch.generative.test import utils as test_utils
  import numpy as np
@@ -109,6 +113,17 @@ class TestModelConversion(googletest.TestCase):
        config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
    )
 
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_phi3(self):
+     config = phi3.get_fake_model_config()
+     pytorch_model = phi3.Phi3_5Mini(config).eval()
+     self._test_model(
+         config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
+     )
+
    @googletest.skipIf(
        ai_edge_config.Config.use_torch_xla,
        reason="tests with custom ops are not supported on oss",
    )
@@ -127,6 +142,110 @@ class TestModelConversion(googletest.TestCase):
      pytorch_model = openelm.OpenELM(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_stable_diffusion_clip(self):
+     config = sd_clip.get_fake_model_config()
+     prompt_tokens = torch.from_numpy(
+         np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
+     )
+
+     pytorch_model = sd_clip.CLIP(config).eval()
+     torch_output = pytorch_model(prompt_tokens)
+
+     edge_model = ai_edge_torch.signature(
+         "encode", pytorch_model, (prompt_tokens,)
+     ).convert()
+     edge_model.set_interpreter_builder(
+         self._interpreter_builder(edge_model.tflite_model())
+     )
+     edge_output = edge_model(
+         prompt_tokens.numpy(),
+         signature_name="encode",
+     )
+     self.assertTrue(
+         np.allclose(
+             edge_output,
+             torch_output.detach().numpy(),
+             atol=1e-4,
+             rtol=1e-5,
+         )
+     )
+
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_stable_diffusion_diffusion(self):
+     config = sd_diffusion.get_fake_model_config(2)
+     latents = torch.from_numpy(
+         np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+     )
+     context = torch.from_numpy(
+         np.random.normal(size=(2, 4, 4)).astype(np.float32)
+     )
+     time_embedding = torch.from_numpy(
+         np.random.normal(size=(2, 2)).astype(np.float32)
+     )
+
+     pytorch_model = sd_diffusion.Diffusion(config).eval()
+     torch_output = pytorch_model(latents, context, time_embedding)
+
+     edge_model = ai_edge_torch.signature(
+         "diffusion", pytorch_model, (latents, context, time_embedding)
+     ).convert()
+     edge_model.set_interpreter_builder(
+         self._interpreter_builder(edge_model.tflite_model())
+     )
+     edge_output = edge_model(
+         latents.numpy(),
+         context.numpy(),
+         time_embedding.numpy(),
+         signature_name="diffusion",
+     )
+     self.assertTrue(
+         np.allclose(
+             edge_output,
+             torch_output.detach().numpy(),
+             atol=1e-4,
+             rtol=1e-5,
+         )
+     )
+
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_stable_diffusion_decoder(self):
+     config = sd_decoder.get_fake_model_config()
+     latents = torch.from_numpy(
+         np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+     )
+
+     pytorch_model = sd_decoder.Decoder(config).eval()
+     torch_output = pytorch_model(latents)
+
+     edge_model = ai_edge_torch.signature(
+         "decode", pytorch_model, (latents,)
+     ).convert()
+     edge_model.set_interpreter_builder(
+         self._interpreter_builder(edge_model.tflite_model())
+     )
+     edge_output = edge_model(
+         latents.numpy(),
+         signature_name="decode",
+     )
+     self.assertTrue(
+         np.allclose(
+             edge_output,
+             torch_output.detach().numpy(),
+             atol=1e-4,
+             rtol=1e-5,
+         )
+     )
+
 
  if __name__ == "__main__":
    googletest.main()
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20240923"
+ __version__ = "0.3.0.dev20240925"
{ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240923
+ Version: 0.3.0.dev20240925
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,6 +30,7 @@ Requires-Dist: tabulate
  Requires-Dist: torch>=2.4.0
  Requires-Dist: torch-xla>=2.4.0
  Requires-Dist: tf-nightly>=2.18.0.dev20240722
+ Requires-Dist: ai-edge-litert-nightly
  Requires-Dist: ai-edge-quantizer-nightly
 
  Library that supports converting PyTorch models into a .tflite format, which can
{ai_edge_torch_nightly-0.3.0.dev20240923.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=oxtOOEY9LJkV5vRrgr1EoSjAjuetYVNq7WQqMuauRkc,706
+ ai_edge_torch/version.py,sha256=UXj1-90S3RDoHwYSmy9VdMC0Sm3EHt9ESLZbi3hnWus,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -48,22 +48,25 @@ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
- ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=DIDzpG8DZkWDcWsAVkcxzxIC3U3352uVI3zMoYZD16U,9554
+ ai_edge_torch/generative/examples/phi/verify.py,sha256=5pQ0Bt8vGl8uTpkgXvOx8G7_rju0Gi8mIEr5NtRSAbs,2145
+ ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=o1UTqpimkeX3MDjgdG1QTQkoZHvCEnGClA0J0WB3wJ4,2328
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=2RMi5UmfMT4Ep68ZLJsqF-fMvEumNVkIwqtsRli9HhA,6068
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ZTRD56e8MsdGPJr7vpLa4Ju_BFw_b-FUgXgd-SO5MBw,15665
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6FAnevL8ZfCK2YCSPivarUH0Z8wGKSmnPpJNC0OI5A8,33680
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -78,8 +81,9 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F7
  ai_edge_torch/generative/examples/t5/t5.py,sha256=OZ67knK-UB1dBjxydG-Jwkp0Z3FzOCqGPTdg5aBFu4w,21328
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
+ ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LTuzres5DHmrMT6U9rCrGf6vmR9SmopmB8sO6Cd2NxQ,5255
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xDYTh4m3vBEb6r3_ERhmj5qILW7YdVDAnZ-fitgYONg,4450
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
@@ -89,15 +93,15 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkD
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
- ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
+ ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
- ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
- ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
+ ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
+ ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -111,7 +115,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IzW2HjXS2-zePZM-qEuXL4zclnGvYsNw-6tuDSeNna4,8163
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -166,8 +170,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/METADATA,sha256=BgwLxDJ3AOPVn0fkngAQpf3YdmShufhMt3bANFevtiQ,1859
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/METADATA,sha256=5KsshdZ4-3X193HkoO2ukceyDEdWGvb8ZEMcw88qt7k,1897
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/RECORD,,