ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240925__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (19)
  1. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  2. ai_edge_torch/generative/examples/phi/phi3.py +286 -0
  3. ai_edge_torch/generative/examples/phi/verify.py +0 -1
  4. ai_edge_torch/generative/examples/phi/verify_phi3.py +68 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/clip.py +52 -1
  6. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +56 -0
  7. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +69 -1
  8. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  9. ai_edge_torch/generative/examples/test_models/toy_model.py +2 -31
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +2 -56
  11. ai_edge_torch/generative/layers/normalization.py +2 -2
  12. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  13. ai_edge_torch/generative/test/test_model_conversion_large.py +119 -0
  14. ai_edge_torch/version.py +1 -1
  15. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA +1 -1
  16. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD +19 -15
  17. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/LICENSE +0 -0
  18. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/WHEEL +0 -0
  19. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py ADDED
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a Phi-3.5 model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import converter
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = phi3.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi3_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
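
Note: for readers who want to exercise the new converter without the absl flag plumbing, the sketch below calls the same two entry points directly. This is a minimal sketch assuming the module layout shown in this diff; the checkpoint path is a hypothetical placeholder.

    import os

    from ai_edge_torch.generative.examples.phi import phi3
    from ai_edge_torch.generative.utilities import converter

    # Hypothetical local checkpoint directory; substitute your own download.
    checkpoint = os.path.join(os.path.expanduser('~'), 'Downloads/llm_data/phi3')

    # Mirrors main() above: build the reauthored model, then convert it.
    pytorch_model = phi3.build_model(checkpoint, kv_cache_max_len=1280)
    converter.convert_to_tflite(
        pytorch_model,
        tflite_path='/tmp/phi3_q8_seq1024_ekv1280.tflite',
        prefill_seq_len=1024,
        quantize=True,
    )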
ai_edge_torch/generative/examples/phi/phi3.py ADDED
@@ -0,0 +1,286 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a Phi-3.5 model that supports up to 4K tokens, not 128K."""
+
+import math
+from typing import Tuple
+
+from ai_edge_torch.generative.layers import attention
+from ai_edge_torch.generative.layers import builder
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import torch
+from torch import nn
+
+TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
+    ff_up_proj="model.layers.{}.mlp.gate_up_proj",
+    ff_down_proj="model.layers.{}.mlp.down_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
+    attn_output_proj="model.layers.{}.self_attn.o_proj",
+    pre_attn_norm="model.layers.{}.input_layernorm",
+    post_attn_norm="model.layers.{}.post_attention_layernorm",
+    embedding="model.embed_tokens",
+    final_norm="model.norm",
+    lm_head="lm_head",
+)
+
+# max_position_embeddings / original_max_position_embeddings in Phi-3.5 config.
+ROPE_SCALE_FACTOR = 32
+
+# RoPE short factor in the Phi-3.5 config. According to the LongRoPE paper and
+# its code at https://github.com/microsoft/LongRoPE, these values were searched
+# with min=1.0, step=0.01 to minimize the error on a sample dataset.
+ROPE_SHORT_FACTOR = [
+    1.0,
+    1.0199999809265137,
+    1.0299999713897705,
+    1.0299999713897705,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0499999523162842,
+    1.0699999332427979,
+    1.0999999046325684,
+    1.1099998950958252,
+    1.1599998474121094,
+    1.1599998474121094,
+    1.1699998378753662,
+    1.2899998426437378,
+    1.339999794960022,
+    1.679999828338623,
+    1.7899998426437378,
+    1.8199998140335083,
+    1.8499997854232788,
+    1.8799997568130493,
+    1.9099997282028198,
+    1.9399996995925903,
+    1.9899996519088745,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0199997425079346,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0299997329711914,
+    2.0799996852874756,
+    2.0899996757507324,
+    2.189999580383301,
+    2.2199995517730713,
+    2.5899994373321533,
+    2.729999542236328,
+    2.749999523162842,
+    2.8399994373321533,
+]
+
+
+def build_rope_cache(
+    size: int,
+    dim: int,
+    base: int = 10000,
+    condense_ratio: int = 1,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = None,
+    theta_factors: torch.Tensor = None,
+    scale: float = 1.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Precomputes Rotary Positional Embeddings for the Phi-3.5 model.
+
+  It's a modified version of attn_utils.build_rope_cache with additional
+  arguments for the Phi-3.5 model. It precomputes Rotary Positional Embedding
+  sin and cos values with scaling factors for quick lookup during inference.
+
+  Args:
+    size (int): The size of the built cache.
+    dim (int): Each sequence's dimension.
+    base (int, optional): Rope base value. Defaults to 10000.
+    condense_ratio (int, optional): The ratio by which sequence indices are
+      condensed. Defaults to 1.
+    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+      torch.float32.
+    device (torch.device, optional): Output tensor's device. Defaults to
+      None, in which case "cpu" is used.
+    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
+      scale the theta values. Defaults to None.
+    scale (float, optional): A float used to scale the rope values. Defaults
+      to 1.0.
+
+  Returns:
+    Tuple[torch.Tensor, torch.Tensor]: RoPE's cosine and sine tensors.
+  """
+  if device is None:
+    device = torch.device('cpu')
+  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  if theta_factors is not None:
+    theta = theta / theta_factors
+  seq_idx = torch.arange(size) / condense_ratio
+  idx_theta = torch.outer(seq_idx, theta)
+  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
+  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
+  return cos, sin
+
+
+class Phi3_5Mini(nn.Module):
+  """A Phi-3.5 model built from the Edge Generative API layers."""
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__()
+
+    # Construct model layers.
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(
+        config.vocab_size, config.embedding_dim, padding_idx=0
+    )
+    # Phi-3.5 has only one block config.
+    block_config = config.block_config(0)
+    self.transformer_blocks = nn.ModuleList(
+        attention.TransformerBlock(block_config, config)
+        for _ in range(config.num_layers)
+    )
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    attn_config = block_config.attn_config
+    self.rope_cache = build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+        scale=math.sqrt(
+            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
+        ),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+        dtype=torch.float32,
+        device=torch.device("cpu"),
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(
+      self,
+      tokens: torch.Tensor,
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    _, seq_len = tokens.size()
+    assert self.config.max_seq_len >= seq_len, (
+        f"Cannot forward sequence of length {seq_len}, max seq length is only"
+        f" {self.config.max_seq_len}"
+    )
+    assert len(self.transformer_blocks) == len(kv_cache.caches), (
+        "The number of transformer blocks and the number of KV cache entries"
+        " must be the same."
+    )
+
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.kv_cache_max]
+
+    x = self.tok_embedding(tokens)
+
+    updated_kv_entries = []
+    for i, block in enumerate(self.transformer_blocks):
+      kv_entry = kv_cache.caches[i] if kv_cache else None
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      if kv_entry:
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+
+    x = self.final_norm(x)
+    logits = self.lm_head(x)  # (b, t, vocab_size)
+    return {"logits": logits, "kv_cache": updated_kv_cache}
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a Phi-3.5 model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a Phi-3.5 model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=32,
+      head_dim=96,
+      num_query_groups=32,
+      rotary_percentage=1.0,
+      qkv_transpose_before_split=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU_GLU),
+      intermediate_size=8192,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=32064,
+      num_layers=32,
+      max_seq_len=4096,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=3072,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  config = get_model_config(kv_cache_max_len)
+  config.vocab_size = 128
+  config.num_layers = 2
+  config.max_seq_len = 2 * kv_cache_max_len
+  # Phi-3.5 has only one block config.
+  config.block_config(0).ff_config.intermediate_size = 128
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  """Instantiates the model instance and loads the checkpoint if provided."""
+  config = get_model_config(**kwargs)
+  model = Phi3_5Mini(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  loader.load(model)
+  model.eval()
+  return model
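
Note: a quick way to see what the LongRoPE-style scaling in build_rope_cache amounts to. With Phi-3.5-mini's config values (ROPE_SCALE_FACTOR=32, max_seq_len=4096) the scale is sqrt(1 + log 32 / log 4096) ≈ 1.1902, applied uniformly to every cos/sin entry. A minimal standalone sketch:

    import math

    import torch

    # The scale used in Phi3_5Mini.__init__ above:
    # sqrt(1 + log(ROPE_SCALE_FACTOR) / log(max_seq_len)) with 32 and 4096.
    scale = math.sqrt(1 + math.log(32) / math.log(4096))
    print(round(scale, 4))  # 1.1902 -- multiplies every cos/sin entry

    # Shape check: head_dim=96 gives 48 theta values, and a kv_cache_max of
    # 1280 yields cos/sin lookup tables of shape (1280, 48).
    theta = 1.0 / (10000 ** (torch.arange(0, 96, 2).float() / 96))
    idx_theta = torch.outer(torch.arange(1280).float(), theta)
    print(torch.cos(idx_theta).shape)  # torch.Size([1280, 48])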
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -27,7 +27,6 @@ _PROMPTS = flags.DEFINE_multi_string(
     "Instruct: Write an email about the weather Output:",
     "The input prompts to generate answers.",
 )
-
 _MAX_NEW_TOKENS = flags.DEFINE_integer(
     "max_new_tokens",
     30,
ai_edge_torch/generative/examples/phi/verify_phi3.py ADDED
@@ -0,0 +1,68 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-3.5 model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Instruct: Write an email about the weather Output:",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = "microsoft/Phi-3.5-mini-instruct"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+      hf_generation_config=generation_config,
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = phi3.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      generate_prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/clip.py CHANGED
@@ -48,7 +48,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
 class CLIP(nn.Module):
-  """CLIP text encoder
+  """CLIP text encoder.
 
   For details, see https://arxiv.org/abs/2103.00020
   """
@@ -86,6 +86,7 @@ class CLIP(nn.Module):
 
 
 def get_model_config() -> cfg.ModelConfig:
+  """Get configs for the CLIP of Stable Diffusion v1.5."""
   max_seq_len = 77
   vocab_size = 49408
   num_layers = 12
@@ -132,3 +133,53 @@ def get_model_config() -> cfg.ModelConfig:
   )
 
   return config
+
+
+def get_fake_model_config() -> cfg.ModelConfig:
+  """Get fake configs for the CLIP of Stable Diffusion v1.5 for testing."""
+  max_seq_len = 6
+  vocab_size = 100
+  num_layers = 2
+  num_heads = 12
+  num_query_groups = 12
+  embedding_dim = 24
+
+  attn_config = cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=0.0,
+      qkv_use_bias=True,
+      qkv_transpose_before_split=True,
+      qkv_fused_interleaved=False,
+      output_proj_use_bias=True,
+      enable_kv_cache=False,
+  )
+
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
+      intermediate_size=embedding_dim * 4,
+      use_bias=True,
+  )
+
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+
+  block_config = cfg.TransformerBlockConfig(
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+  )
+
+  config = cfg.ModelConfig(
+      vocab_size=vocab_size,
+      num_layers=num_layers,
+      max_seq_len=max_seq_len,
+      embedding_dim=embedding_dim,
+      block_configs=block_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+
+  return config
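
Note: the fake config is sized for fast unit tests; the sketch below instantiates it and runs a forward pass the same way the new test_stable_diffusion_clip test (later in this diff) does. A minimal sketch assuming the modules from this wheel; the expected output shape is an assumption following from embedding_dim=24 and max_seq_len=6.

    import numpy as np
    import torch

    from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip

    config = sd_clip.get_fake_model_config()
    model = sd_clip.CLIP(config).eval()

    # Six tokens to match max_seq_len=6 in the fake config.
    tokens = torch.from_numpy(np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32))
    with torch.no_grad():
      output = model(tokens)
    print(output.shape)  # expected (1, 6, 24): (batch, seq, embedding_dim)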
ai_edge_torch/generative/examples/stable_diffusion/decoder.py CHANGED
@@ -324,3 +324,59 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
       mid_block_config=mid_block_config,
   )
   return config
+
+
+def get_fake_model_config() -> unet_cfg.AutoEncoderConfig:
+  """Get fake configs for the Decoder of Stable Diffusion v1.5 for testing."""
+  in_channels = 3
+  latent_channels = 4
+  out_channels = 3
+  block_out_channels = [2, 4]
+  scaling_factor = 0.18215
+  layers_per_block = 2
+
+  norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+
+  att_config = unet_cfg.AttentionBlock2DConfig(
+      dim=block_out_channels[-1],
+      normalization_config=norm_config,
+      attention_config=layers_cfg.AttentionConfig(
+          num_heads=1,
+          head_dim=block_out_channels[-1],
+          num_query_groups=1,
+          qkv_use_bias=True,
+          output_proj_use_bias=True,
+          enable_kv_cache=False,
+          qkv_transpose_before_split=True,
+          qkv_fused_interleaved=False,
+          rotary_percentage=0.0,
+      ),
+      enable_hlfb=False,
+  )
+
+  mid_block_config = unet_cfg.MidBlock2DConfig(
+      in_channels=block_out_channels[-1],
+      normalization_config=norm_config,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      num_layers=1,
+      attention_block_config=att_config,
+  )
+
+  config = unet_cfg.AutoEncoderConfig(
+      in_channels=in_channels,
+      latent_channels=latent_channels,
+      out_channels=out_channels,
+      activation_config=layers_cfg.ActivationConfig(
+          layers_cfg.ActivationType.SILU
+      ),
+      block_out_channels=block_out_channels,
+      scaling_factor=scaling_factor,
+      layers_per_block=layers_per_block,
+      normalization_config=norm_config,
+      mid_block_config=mid_block_config,
+  )
+  return config
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py CHANGED
@@ -603,7 +603,7 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
   # Transformer configs.
   transformer_num_attention_heads = 8
   transformer_batch_size = batch_size
-  transformer_cross_attention_dim = 768  # Embedding fomr CLIP model
+  transformer_cross_attention_dim = 768  # Embedding from CLIP model
   transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
       layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=32
   )
@@ -645,3 +645,71 @@ def get_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
       final_norm_config=final_norm_config,
       final_activation_type=final_activation_type,
   )
+
+
+def get_fake_model_config(batch_size: int) -> unet_cfg.DiffusionModelConfig:
+  """Get fake configs for the Diffusion model of Stable Diffusion v1.5 for testing.
+
+  Args:
+    batch_size (int): the batch size of input.
+
+  Returns:
+    The configuration of diffusion model of Stable Diffusion v1.5.
+  """
+  in_channels = 4
+  out_channels = 4
+  block_out_channels = [2, 4, 8, 8]
+  layers_per_block = 1
+  downsample_padding = 1
+
+  # Residual configs.
+  residual_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  residual_activation_type = layers_cfg.ActivationType.SILU
+
+  # Transformer configs.
+  transformer_num_attention_heads = 1
+  transformer_batch_size = batch_size
+  transformer_cross_attention_dim = 4  # Embedding from CLIP model
+  transformer_pre_conv_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, epsilon=1e-6, group_num=2
+  )
+  transformer_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.LAYER_NORM
+  )
+  transformer_ff_activation_type = layers_cfg.ActivationType.GE_GLU
+
+  # Time embedding configs.
+  time_embedding_dim = 2
+  time_embedding_blocks_dim = 4
+
+  # Mid block configs.
+  mid_block_layers = 1
+
+  # Final layer configs.
+  final_norm_config = layers_cfg.NormalizationConfig(
+      layers_cfg.NormalizationType.GROUP_NORM, group_num=2
+  )
+  final_activation_type = layers_cfg.ActivationType.SILU
+
+  return unet_cfg.DiffusionModelConfig(
+      in_channels=in_channels,
+      out_channels=out_channels,
+      block_out_channels=block_out_channels,
+      layers_per_block=layers_per_block,
+      downsample_padding=downsample_padding,
+      residual_norm_config=residual_norm_config,
+      residual_activation_type=residual_activation_type,
+      transformer_batch_size=transformer_batch_size,
+      transformer_num_attention_heads=transformer_num_attention_heads,
+      transformer_cross_attention_dim=transformer_cross_attention_dim,
+      transformer_pre_conv_norm_config=transformer_pre_conv_norm_config,
+      transformer_norm_config=transformer_norm_config,
+      transformer_ff_activation_type=transformer_ff_activation_type,
+      mid_block_layers=mid_block_layers,
+      time_embedding_dim=time_embedding_dim,
+      time_embedding_blocks_dim=time_embedding_blocks_dim,
+      final_norm_config=final_norm_config,
+      final_activation_type=final_activation_type,
+  )
ai_edge_torch/generative/examples/test_models/convert_toy_model.py ADDED
@@ -0,0 +1,105 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# A toy example which has a single-layer transformer block.
+from absl import app
+import ai_edge_torch
+from ai_edge_torch import lowertools
+from ai_edge_torch.generative.examples.test_models import toy_model
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import torch
+
+KV_CACHE_MAX_LEN = 100
+
+
+def convert_toy_model(_) -> None:
+  """Converts a toy model to tflite."""
+  model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
+  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
+  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
+  print('running an inference')
+  print(
+      model.forward(
+          idx,
+          input_pos,
+      )
+  )
+
+  # Convert model to tflite.
+  print('converting model to tflite')
+  edge_model = ai_edge_torch.convert(
+      model,
+      (
+          idx,
+          input_pos,
+      ),
+  )
+  edge_model.export('/tmp/toy_model.tflite')
+
+
+def _export_stablehlo_mlir(model, args):
+  ep = torch.export.export(model, args)
+  return lowertools.exported_program_to_mlir_text(ep)
+
+
+def convert_toy_model_with_kv_cache(_) -> None:
+  """Converts a toy model with kv cache to tflite."""
+  dump_mlir = False
+
+  config = toy_model_with_kv_cache.get_model_config()
+  model = toy_model_with_kv_cache.ToyModelWithKVCache(config)
+  model.eval()
+  print('running an inference')
+  kv = kv_utils.KVCache.from_model_config(config)
+
+  tokens, input_pos = toy_model_with_kv_cache.get_sample_prefill_inputs()
+  decode_token, decode_input_pos = (
+      toy_model_with_kv_cache.get_sample_decode_inputs()
+  )
+  print(model.forward(tokens, input_pos, kv))
+
+  if dump_mlir:
+    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
+    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
+      f.write(mlir_text)
+
+  # Convert model to tflite with 2 signatures (prefill + decode).
+  print('converting toy model to tflite with 2 signatures (prefill + decode)')
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          model,
+          sample_kwargs={
+              'tokens': tokens,
+              'input_pos': input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert()
+  )
+  edge_model.export('/tmp/toy_external_kv_cache.tflite')
+
+
+if __name__ == '__main__':
+  app.run(convert_toy_model)
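
Note: a hedged sketch of inspecting the exported two-signature model with the TensorFlow Lite interpreter, assuming convert_toy_model_with_kv_cache has been run to produce the file at the path used above. A full invocation would also need the flattened kv_cache input tensors.

    import tensorflow as tf

    # Path written by convert_toy_model_with_kv_cache above.
    interpreter = tf.lite.Interpreter(
        model_path='/tmp/toy_external_kv_cache.tflite'
    )
    print(interpreter.get_signature_list())  # expect 'prefill' and 'decode'
    prefill_runner = interpreter.get_signature_runner('prefill')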
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -15,13 +15,12 @@
 # A toy example which has a single-layer transformer block.
 from typing import Tuple
 
-import ai_edge_torch
+from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers.attention import TransformerBlock
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
-import ai_edge_torch.generative.layers.builder as builder
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 KV_CACHE_MAX_LEN = 100
@@ -149,31 +148,3 @@ def get_model_config() -> cfg.ModelConfig:
       final_norm_config=norm_config,
   )
   return config
-
-
-def define_and_run() -> None:
-  model = ToySingleLayerModel(get_model_config())
-  idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
-  input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
-  print('running an inference')
-  print(
-      model.forward(
-          idx,
-          input_pos,
-      )
-  )
-
-  # Convert model to tflite.
-  print('converting model to tflite')
-  edge_model = ai_edge_torch.convert(
-      model,
-      (
-          idx,
-          input_pos,
-      ),
-  )
-  edge_model.export('/tmp/toy_model.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -17,15 +17,14 @@
 
 from typing import Tuple
 
-import ai_edge_torch
-from ai_edge_torch import lowertools
+from absl import app
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch
-import torch.nn as nn
+from torch import nn
 
 RoPECache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -87,11 +86,6 @@ class ToyModelWithKVCache(torch.nn.Module):
     return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
 
 
-def _export_stablehlo_mlir(model, args):
-  ep = torch.export.export(model, args)
-  return lowertools.exported_program_to_mlir_text(ep)
-
-
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
@@ -133,51 +127,3 @@ def get_sample_decode_inputs() -> Tuple[torch.Tensor, torch.Tensor]:
   tokens = torch.tensor([[1]], dtype=torch.int)
   input_pos = torch.tensor([10])
   return tokens, input_pos
-
-
-def define_and_run() -> None:
-  dump_mlir = False
-
-  config = get_model_config()
-  model = ToyModelWithExternalKV(config)
-  model.eval()
-  print('running an inference')
-  kv = kv_utils.KVCache.from_model_config(config)
-
-  tokens, input_pos = get_sample_prefill_inputs()
-  decode_token, decode_input_pos = get_sample_decode_inputs()
-  print(model.forward(tokens, input_pos, kv))
-
-  if dump_mlir:
-    mlir_text = _export_stablehlo_mlir(model, (tokens, input_pos, kv))
-    with open('/tmp/toy_model_with_external_kv.stablehlo.mlir', 'w') as f:
-      f.write(mlir_text)
-
-  # Convert model to tflite with 2 signatures (prefill + decode).
-  print('converting toy model to tflite with 2 signatures (prefill + decode)')
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          model,
-          sample_kwargs={
-              'tokens': tokens,
-              'input_pos': input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert()
-  )
-  edge_model.export('/tmp/toy_external_kv_cache.tflite')
-
-
-if __name__ == '__main__':
-  define_and_run()
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -189,7 +189,7 @@ def group_norm_with_hlfb(
       name="odml.group_norm",
       attr={
          "num_groups": num_groups,
-          "eps": eps,
+          "epsilon": eps,
          "reduction_axes": 3,
          "channel_axis": 3,
       },
@@ -226,7 +226,7 @@ def layer_norm_with_hlfb(
   """
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
-      attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
+      attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
   )
   x, w, b = builder.mark_inputs(x, w, b)
   if use_input_shape:
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
       time_emb: Optional[torch.Tensor] = None,
      context_tensor: Optional[torch.Tensor] = None,
      output_hidden_states: bool = False,
-  ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Forward function of the DownEncoderBlock2D.
 
    Args:
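
Note: the return-annotation change above swaps the PEP 604 union syntax (X | Y) for typing.Union. The | form is only valid at runtime on Python 3.10+, while Union also works on older interpreters, which is presumably the motivation here. A minimal illustration (the stub names are hypothetical):

    from typing import List, Tuple, Union

    import torch

    # Annotations are evaluated at function-definition time, so they must
    # parse on every supported Python version; `torch.Tensor | Tuple[...]`
    # raises TypeError before Python 3.10.
    BlockOutput = Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]

    def forward_stub(output_hidden_states: bool = False) -> BlockOutput:
      hidden = torch.zeros(1, 4, 8, 8)
      if output_hidden_states:
        return hidden, [hidden]
      return hidden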
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -21,7 +21,11 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -109,6 +113,17 @@ class TestModelConversion(googletest.TestCase):
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_phi3(self):
+    config = phi3.get_fake_model_config()
+    pytorch_model = phi3.Phi3_5Mini(config).eval()
+    self._test_model(
+        config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -127,6 +142,110 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_clip(self):
+    config = sd_clip.get_fake_model_config()
+    prompt_tokens = torch.from_numpy(
+        np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
+    )
+
+    pytorch_model = sd_clip.CLIP(config).eval()
+    torch_output = pytorch_model(prompt_tokens)
+
+    edge_model = ai_edge_torch.signature(
+        "encode", pytorch_model, (prompt_tokens,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        prompt_tokens.numpy(),
+        signature_name="encode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_diffusion(self):
+    config = sd_diffusion.get_fake_model_config(2)
+    latents = torch.from_numpy(
+        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+    )
+    context = torch.from_numpy(
+        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+    )
+    time_embedding = torch.from_numpy(
+        np.random.normal(size=(2, 2)).astype(np.float32)
+    )
+
+    pytorch_model = sd_diffusion.Diffusion(config).eval()
+    torch_output = pytorch_model(latents, context, time_embedding)
+
+    edge_model = ai_edge_torch.signature(
+        "diffusion", pytorch_model, (latents, context, time_embedding)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        context.numpy(),
+        time_embedding.numpy(),
+        signature_name="diffusion",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_decoder(self):
+    config = sd_decoder.get_fake_model_config()
+    latents = torch.from_numpy(
+        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+    )
+
+    pytorch_model = sd_decoder.Decoder(config).eval()
+    torch_output = pytorch_model(latents)
+
+    edge_model = ai_edge_torch.signature(
+        "decode", pytorch_model, (latents,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        signature_name="decode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240924"
+__version__ = "0.3.0.dev20240925"
{ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240924
+Version: 0.3.0.dev20240925
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240925.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=sQUcRP5rShDk3vfblz87j26JciN6PV8S8DJkiiZP5o8,706
+ai_edge_torch/version.py,sha256=UXj1-90S3RDoHwYSmy9VdMC0Sm3EHt9ESLZbi3hnWus,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -51,19 +51,22 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=SwPyRjiupD4AsmWW_7FDcMSWaNRmDBu6uVFcBQRoM40,2146
+ai_edge_torch/generative/examples/phi/phi3.py,sha256=DIDzpG8DZkWDcWsAVkcxzxIC3U3352uVI3zMoYZD16U,9554
+ai_edge_torch/generative/examples/phi/verify.py,sha256=5pQ0Bt8vGl8uTpkgXvOx8G7_rju0Gi8mIEr5NtRSAbs,2145
+ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=o1UTqpimkeX3MDjgdG1QTQkoZHvCEnGClA0J0WB3wJ4,2328
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=2RMi5UmfMT4Ep68ZLJsqF-fMvEumNVkIwqtsRli9HhA,6068
 ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
-ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
+ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ZTRD56e8MsdGPJr7vpLa4Ju_BFw_b-FUgXgd-SO5MBw,15665
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6FAnevL8ZfCK2YCSPivarUH0Z8wGKSmnPpJNC0OI5A8,33680
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -78,8 +81,9 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F7
 ai_edge_torch/generative/examples/t5/t5.py,sha256=OZ67knK-UB1dBjxydG-Jwkp0Z3FzOCqGPTdg5aBFu4w,21328
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
+ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LTuzres5DHmrMT6U9rCrGf6vmR9SmopmB8sO6Cd2NxQ,5255
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xDYTh4m3vBEb6r3_ERhmj5qILW7YdVDAnZ-fitgYONg,4450
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
@@ -93,11 +97,11 @@ ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
 ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
-ai_edge_torch/generative/layers/normalization.py,sha256=LDczSHujMgo1WV8IhYVQe-egPkaBEmWFt8wZQ_tgshg,6991
+ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQzmKfHdRWaaE1EbaMI4Gs,27540
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -111,7 +115,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IzW2HjXS2-zePZM-qEuXL4zclnGvYsNw-6tuDSeNna4,8163
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -166,8 +170,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/METADATA,sha256=BotYlw1pMxClnHOi8rSb5v6jX0zE7EqUo8b11xvqEII,1897
-ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240924.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/METADATA,sha256=5KsshdZ4-3X193HkoO2ukceyDEdWGvb8ZEMcw88qt7k,1897
+ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/RECORD,,