ai-edge-torch-nightly 0.3.0.dev20240813__py3-none-any.whl → 0.3.0.dev20240815__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.
Files changed (27)
  1. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +2 -2
  2. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +2 -2
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +67 -0
  4. ai_edge_torch/generative/examples/gemma/gemma.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +250 -0
  6. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -2
  7. ai_edge_torch/generative/examples/t5/t5.py +4 -4
  8. ai_edge_torch/generative/examples/t5/t5_attention.py +3 -3
  9. ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
  10. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
  11. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
  12. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
  13. ai_edge_torch/generative/layers/attention.py +12 -5
  14. ai_edge_torch/generative/layers/attention_utils.py +30 -0
  15. ai_edge_torch/generative/layers/builder.py +5 -0
  16. ai_edge_torch/generative/layers/feed_forward.py +15 -3
  17. ai_edge_torch/generative/layers/model_config.py +35 -13
  18. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +25 -9
  19. ai_edge_torch/generative/test/test_model_conversion.py +29 -1
  20. ai_edge_torch/generative/utilities/loader.py +29 -7
  21. ai_edge_torch/generative/utilities/t5_loader.py +8 -8
  22. ai_edge_torch/version.py +1 -1
  23. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240815.dist-info}/METADATA +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240815.dist-info}/RECORD +27 -25
  25. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240815.dist-info}/LICENSE +0 -0
  26. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240815.dist-info}/WHEEL +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20240813.dist-info → ai_edge_torch_nightly-0.3.0.dev20240815.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/experimental/gemma/gemma.py CHANGED
@@ -40,7 +40,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head=None,
@@ -150,7 +150,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  parallel_residual=False,
  lm_head_use_bias=False,
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py CHANGED
@@ -41,7 +41,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head="lm_head",
@@ -142,7 +142,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  enable_hlfb=True,
  )
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ import os
+ from pathlib import Path
+
+ import ai_edge_torch
+ from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.quantize import quant_recipes
+ import torch
+
+
+ def convert_gemma_to_tflite(
+ checkpoint_path: str,
+ prefill_seq_len: int = 512,
+ kv_cache_max_len: int = 1024,
+ quantize: bool = True,
+ ):
+ """Converts a Gemma 2 2B model to a multi-signature TFLite model.
+
+ Args:
+ checkpoint_path (str): The filepath to the model checkpoint, or directory holding the checkpoint.
+ prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+ Defaults to 512.
+ kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+ including both prefill and decode. Defaults to 1024.
+ quantize (bool, optional): Whether the model should be quantized.
+ Defaults to True.
+ """
+ pytorch_model = gemma2.build_2b_model(
+ checkpoint_path, kv_cache_max_len=kv_cache_max_len
+ )
+ # Tensors used to trace the model graph during conversion.
+ prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+ prefill_input_pos = torch.arange(0, prefill_seq_len)
+ decode_token = torch.tensor([[0]], dtype=torch.long)
+ decode_input_pos = torch.tensor([0], dtype=torch.int64)
+
+ quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+ edge_model = (
+ ai_edge_torch.signature(
+ 'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
+ )
+ .signature('decode', pytorch_model, (decode_token, decode_input_pos))
+ .convert(quant_config=quant_config)
+ )
+ edge_model.export(
+ f'/tmp/gemma2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+ )
+
+
+ if __name__ == '__main__':
+ checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/gemma2-2b')
+ convert_gemma_to_tflite(checkpoint_path)
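A minimal usage sketch of the new conversion script (the checkpoint directory below is a placeholder; when run as a script it defaults to ~/Downloads/llm_data/gemma2-2b and writes /tmp/gemma2_seq512_kv1024.tflite):

from ai_edge_torch.generative.examples.gemma import convert_gemma2_to_tflite

# Placeholder path: point this at a local Gemma 2 2B checkpoint directory.
convert_gemma2_to_tflite.convert_gemma_to_tflite(
    '/path/to/gemma2-2b',
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,
)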
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  attn_value_proj="model.layers.{}.self_attn.v_proj",
  attn_output_proj="model.layers.{}.self_attn.o_proj",
  pre_attn_norm="model.layers.{}.input_layernorm",
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
  embedding="model.embed_tokens",
  final_norm="model.norm",
  lm_head=None,
@@ -138,7 +138,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  attn_config=attn_config,
  ff_config=ff_config,
  pre_attention_norm_config=norm_config,
- pre_ff_norm_config=norm_config,
+ post_attention_norm_config=norm_config,
  final_norm_config=norm_config,
  parallel_residual=False,
  lm_head_use_bias=False,
@@ -160,6 +160,7 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
  # since embedding and lm-head use the same weight, we need to set strict
  # to False.
  loader.load(model, strict=False)
+ model.eval()
  return model


@@ -0,0 +1,250 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ # Example of building the Gemma2 2B model.
16
+
17
+ import os
18
+ from pathlib import Path
19
+ from typing import Optional, Tuple
20
+
21
+ from ai_edge_torch.generative.layers.attention import TransformerBlock
22
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
23
+ import ai_edge_torch.generative.layers.builder as builder
24
+ import ai_edge_torch.generative.layers.model_config as cfg
25
+ import ai_edge_torch.generative.utilities.loader as loading_utils
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
31
+ ff_up_proj="model.layers.{}.mlp.up_proj",
32
+ ff_down_proj="model.layers.{}.mlp.down_proj",
33
+ ff_gate_proj="model.layers.{}.mlp.gate_proj",
34
+ attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
35
+ attn_output_proj="model.layers.{}.self_attn.o_proj",
36
+ pre_attn_norm="model.layers.{}.input_layernorm",
37
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
38
+ pre_ff_norm="model.layers.{}.pre_feedforward_layernorm",
39
+ post_ff_norm="model.layers.{}.post_feedforward_layernorm",
40
+ embedding="embedder",
41
+ final_norm="model.norm",
42
+ lm_head=None,
43
+ )
44
+
45
+
46
+ class Gemma2Block(TransformerBlock):
47
+
48
+ def forward(
49
+ self,
50
+ x: torch.Tensor,
51
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
52
+ mask: Optional[torch.Tensor] = None,
53
+ input_pos: Optional[torch.Tensor] = None,
54
+ ) -> torch.Tensor:
55
+ """Forward function of the Gemma2Block.
56
+
57
+ Exactly the same as TransformerBlock but we call the post-attention norm
58
+ immediately after attention and not after the residual pointwise addition.
59
+
60
+ Args:
61
+ x (torch.Tensor): the input tensor.
62
+ rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
63
+ mask (torch.Tensor): the optional mask tensor.
64
+ input_pos (torch.Tensor): the optional input position tensor.
65
+
66
+ Returns:
67
+ output activation from this transformer block.
68
+ """
69
+
70
+ x_norm = self.pre_atten_norm(x)
71
+ attn_out = self.atten_func(x_norm, rope, mask, input_pos)
72
+ attn_out_norm = self.post_atten_norm(attn_out)
73
+ x = x + attn_out_norm
74
+ output = x + self.ff(x)
75
+ return output
76
+
77
+
78
+ class Gemma2(nn.Module):
79
+
80
+ def __init__(self, config: cfg.ModelConfig):
81
+ super().__init__()
82
+
83
+ self.config = config
84
+ # Construct model layers.
85
+ self.tok_embedding = nn.Embedding(
86
+ config.vocab_size, config.embedding_dim, padding_idx=0
87
+ )
88
+ self.lm_head = nn.Linear(
89
+ config.embedding_dim,
90
+ config.vocab_size,
91
+ bias=config.lm_head_use_bias,
92
+ )
93
+ # Gemma re-uses the embedding as the head projection layer.
94
+ self.lm_head.weight.data = self.tok_embedding.weight.data
95
+ self.transformer_blocks = nn.ModuleList(
96
+ Gemma2Block(config) for _ in range(config.num_layers)
97
+ )
98
+ self.final_norm = builder.build_norm(
99
+ config.embedding_dim,
100
+ config.final_norm_config,
101
+ )
102
+ self.rope_cache = attn_utils.build_rope_cache(
103
+ size=config.kv_cache_max,
104
+ dim=int(
105
+ config.attn_config.rotary_percentage * config.attn_config.head_dim
106
+ ),
107
+ base=10_000,
108
+ condense_ratio=1,
109
+ dtype=torch.float32,
110
+ device=torch.device("cpu"),
111
+ )
112
+ self.mask_cache = attn_utils.build_causal_mask_cache(
113
+ size=config.kv_cache_max,
114
+ dtype=torch.float32,
115
+ device=torch.device("cpu"),
116
+ )
117
+
118
+ self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
119
+ size=config.kv_cache_max,
120
+ window_size=self.config.attn_config.sliding_window_size,
121
+ dtype=torch.float32,
122
+ device=torch.device("cpu"),
123
+ )
124
+
125
+ self.config = config
126
+
127
+ def get_attention_mask(
128
+ self, idx: int, input_pos: torch.Tensor
129
+ ) -> torch.Tensor:
130
+ if self.config.attn_config.attn_types:
131
+ if (
132
+ self.config.attn_config.attn_types[idx]
133
+ == cfg.AttentionType.LOCAL_SLIDING
134
+ ):
135
+ return self.sliding_window_mask_cache.index_select(2, input_pos)
136
+
137
+ return self.mask_cache.index_select(2, input_pos)
138
+
139
+ @torch.inference_mode
140
+ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
141
+ B, T = idx.size()
142
+ assert self.config.max_seq_len >= T, (
143
+ f"Cannot forward sequence of length {T}, max seq length is only"
144
+ f" {self.config.max_seq_len}"
145
+ )
146
+
147
+ cos, sin = self.rope_cache
148
+ cos = cos.index_select(0, input_pos)
149
+ sin = sin.index_select(0, input_pos)
150
+
151
+ # token embeddings of shape (b, t, n_embd)
152
+ x = self.tok_embedding(idx)
153
+ x = x * (self.config.embedding_dim**0.5)
154
+
155
+ for i, block in enumerate(self.transformer_blocks):
156
+ mask = self.get_attention_mask(i, input_pos)
157
+ x = block(x, (cos, sin), mask, input_pos)
158
+
159
+ x = self.final_norm(x)
160
+ res = self.lm_head(x) # (b, t, vocab_size)
161
+ if self.config.final_logit_softcap is not None:
162
+ res = res / self.config.final_logit_softcap
163
+ res = torch.tanh(res)
164
+ res = res * self.config.final_logit_softcap
165
+ return res
166
+
167
+
168
+ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
169
+ attn_config = cfg.AttentionConfig(
170
+ num_heads=8,
171
+ head_dim=256,
172
+ num_query_groups=4,
173
+ rotary_percentage=1.0,
174
+ qkv_transpose_before_split=True,
175
+ logit_softcap=50.0,
176
+ sliding_window_size=4096,
177
+ attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
178
+ * 13,
179
+ )
180
+
181
+ norm_config = cfg.NormalizationConfig(
182
+ type=cfg.NormalizationType.RMS_NORM,
183
+ epsilon=1e-6,
184
+ zero_centered=True,
185
+ )
186
+ ff_config = cfg.FeedForwardConfig(
187
+ type=cfg.FeedForwardType.GATED,
188
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
189
+ intermediate_size=9216,
190
+ pre_ff_norm_config=norm_config,
191
+ post_ff_norm_config=norm_config,
192
+ )
193
+ config = cfg.ModelConfig(
194
+ vocab_size=256000,
195
+ num_layers=26,
196
+ max_seq_len=8192,
197
+ embedding_dim=2304,
198
+ kv_cache_max_len=kv_cache_max_len,
199
+ attn_config=attn_config,
200
+ ff_config=ff_config,
201
+ pre_attention_norm_config=norm_config,
202
+ post_attention_norm_config=norm_config,
203
+ final_norm_config=norm_config,
204
+ parallel_residual=False,
205
+ lm_head_use_bias=False,
206
+ enable_hlfb=False,
207
+ final_logit_softcap=30.0,
208
+ )
209
+ return config
210
+
211
+
212
+ def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
213
+ config = get_model_config_2b()
214
+ config.num_layers = 2
215
+ return config
216
+
217
+
218
+ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
219
+ config = get_model_config_2b(**kwargs)
220
+ model = Gemma2(config)
221
+ loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
222
+ # since embedding and lm-head use the same weight, we need to set strict
223
+ # to False.
224
+ loader.load(model, strict=False)
225
+ model.eval()
226
+ return model
227
+
228
+
229
+ def define_and_run_2b() -> None:
230
+ current_dir = Path(__file__).parent.resolve()
231
+ gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
232
+ print("Running GEMMA 2")
233
+ kv_cache_max_len = 1024
234
+ checkpoint_path = os.path.join(Path.home(), "Downloads/llm_data/gemma2-2b")
235
+ model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
236
+ toks = torch.from_numpy(
237
+ np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
238
+ )
239
+ tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
240
+ tokens[0, :9] = toks
241
+ input_pos = torch.arange(0, kv_cache_max_len)
242
+ out = model.forward(tokens, input_pos)
243
+ out_final = out[0, 8, :]
244
+ assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
245
+ print(out)
246
+
247
+
248
+ if __name__ == "__main__":
249
+ torch.set_printoptions(sci_mode=True)
250
+ define_and_run_2b()
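To make the normalization ordering concrete, here is a small comparison sketch (not code from the package; attn, ff, and the norm callables stand in for the modules the builders create). The base TransformerBlock normalizes the residual stream after the attention residual add and feeds that into the FF, while Gemma2Block normalizes the attention output itself before the residual add; Gemma 2's extra pre/post feed-forward norms are configured through FeedForwardConfig and applied inside the FF module.

def base_transformer_block(x, attn, ff, pre_attn_norm, post_attn_norm):
  # attention.py TransformerBlock: norm -> attn -> add, then norm -> ff -> add.
  x = x + attn(pre_attn_norm(x))
  return x + ff(post_attn_norm(x))

def gemma2_style_block(x, attn, ff, pre_attn_norm, post_attn_norm):
  # gemma2.py Gemma2Block: the post-attention norm wraps the attention output
  # itself; the FF's own pre/post norms live inside ff (see feed_forward.py).
  x = x + post_attn_norm(attn(pre_attn_norm(x)))
  return x + ff(x)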
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
35
35
  pre_attn_norm=(
36
36
  "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1"
37
37
  ),
38
- pre_ff_norm=(
38
+ post_attn_norm=(
39
39
  "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2"
40
40
  ),
41
41
  embedding=(
@@ -120,7 +120,7 @@ def get_model_config() -> cfg.ModelConfig:
120
120
  attn_config=attn_config,
121
121
  ff_config=ff_config,
122
122
  pre_attention_norm_config=norm_config,
123
- pre_ff_norm_config=norm_config,
123
+ post_attention_norm_config=norm_config,
124
124
  final_norm_config=norm_config,
125
125
  enable_hlfb=True,
126
126
  )
@@ -38,7 +38,7 @@ ENCDEC_TENSOR_NAMES = {
38
38
  "{prefix}.block.0.layer.0.SelfAttention.relative_attention_bias"
39
39
  ),
40
40
  "pre_attn_norm": "{prefix}.block.{}.layer.0.layer_norm",
41
- "pre_ff_norm": "{prefix}.block.{}.layer.1.layer_norm",
41
+ "post_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
42
42
  "final_norm": "{prefix}.final_layer_norm",
43
43
  }
44
44
 
@@ -396,7 +396,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
396
396
  relative_attention=True,
397
397
  ff_config=ff_config,
398
398
  pre_attention_norm_config=norm_config,
399
- pre_ff_norm_config=norm_config,
399
+ post_attention_norm_config=norm_config,
400
400
  final_norm_config=norm_config,
401
401
  parallel_residual=False,
402
402
  lm_head_use_bias=False,
@@ -419,7 +419,7 @@ def build_t5_model(checkpoint_path: str) -> nn.Module:
419
419
  "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
420
420
  "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
421
421
  # In the decoder, the FF is layer 2 in the Transformer block
422
- "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
422
+ "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
423
423
  # In the decoder, the cross attention is layer 1 in the Transformer block
424
424
  "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
425
425
  }
@@ -475,7 +475,7 @@ def build_t5_decoder_model(
475
475
  "cross_attn_value_proj": "{prefix}.block.{}.layer.1.EncDecAttention.v",
476
476
  "cross_attn_output_proj": "{prefix}.block.{}.layer.1.EncDecAttention.o",
477
477
  # In the decoder, the FF is layer 2 in the Transformer block
478
- "pre_ff_norm": "{prefix}.block.{}.layer.2.layer_norm",
478
+ "post_attn_norm": "{prefix}.block.{}.layer.2.layer_norm",
479
479
  # In the decoder, the cross attention is layer 1 in the Transformer block
480
480
  "pre_cross_attn_norm": "{prefix}.block.{}.layer.1.layer_norm",
481
481
  }
@@ -68,8 +68,8 @@ class EncoderDecoderBlock(nn.Module):
68
68
  else:
69
69
  self.cross_atten_func = None
70
70
 
71
- self.pre_ff_norm = builder.build_norm(
72
- config.embedding_dim, config.pre_ff_norm_config
71
+ self.post_atten_norm = builder.build_norm(
72
+ config.embedding_dim, config.post_attention_norm_config
73
73
  )
74
74
  self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
75
75
  self.config = config
@@ -118,7 +118,7 @@ class EncoderDecoderBlock(nn.Module):
118
118
  )
119
119
  attn_out = hidden_states + attn_out
120
120
 
121
- forwarded = self.pre_ff_norm(attn_out)
121
+ forwarded = self.post_atten_norm(attn_out)
122
122
  forwarded = self.ff(forwarded)
123
123
  hidden_states = attn_out + forwarded
124
124
 
@@ -93,7 +93,7 @@ def get_model_config() -> cfg.ModelConfig:
93
93
  attn_config=attn_config,
94
94
  ff_config=ff_config,
95
95
  pre_attention_norm_config=norm_config,
96
- pre_ff_norm_config=norm_config,
96
+ post_attention_norm_config=norm_config,
97
97
  final_norm_config=norm_config,
98
98
  )
99
99
  return config
@@ -107,7 +107,7 @@ def get_model_config() -> cfg.ModelConfig:
107
107
  attn_config=attn_config,
108
108
  ff_config=ff_config,
109
109
  pre_attention_norm_config=norm_config,
110
- pre_ff_norm_config=norm_config,
110
+ post_attention_norm_config=norm_config,
111
111
  final_norm_config=norm_config,
112
112
  enable_hlfb=True,
113
113
  )
@@ -94,7 +94,7 @@ def get_model_config() -> cfg.ModelConfig:
94
94
  attn_config=attn_config,
95
95
  ff_config=ff_config,
96
96
  pre_attention_norm_config=norm_config,
97
- pre_ff_norm_config=norm_config,
97
+ post_attention_norm_config=norm_config,
98
98
  final_norm_config=norm_config,
99
99
  enable_hlfb=True,
100
100
  )
@@ -35,7 +35,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
35
35
  attn_value_proj="model.layers.{}.self_attn.v_proj",
36
36
  attn_output_proj="model.layers.{}.self_attn.o_proj",
37
37
  pre_attn_norm="model.layers.{}.input_layernorm",
38
- pre_ff_norm="model.layers.{}.post_attention_layernorm",
38
+ post_attn_norm="model.layers.{}.post_attention_layernorm",
39
39
  embedding="model.embed_tokens",
40
40
  final_norm="model.norm",
41
41
  lm_head="lm_head",
@@ -130,7 +130,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
130
130
  attn_config=attn_config,
131
131
  ff_config=ff_config,
132
132
  pre_attention_norm_config=norm_config,
133
- pre_ff_norm_config=norm_config,
133
+ post_attention_norm_config=norm_config,
134
134
  final_norm_config=norm_config,
135
135
  enable_hlfb=True,
136
136
  )
@@ -74,8 +74,8 @@ class TransformerBlock(nn.Module):
74
74
  config.kv_cache_max,
75
75
  config.enable_hlfb,
76
76
  )
77
- self.pre_ff_norm = builder.build_norm(
78
- config.embedding_dim, config.pre_ff_norm_config
77
+ self.post_atten_norm = builder.build_norm(
78
+ config.embedding_dim, config.post_attention_norm_config
79
79
  )
80
80
  self.ff = builder.build_ff(config.embedding_dim, config.ff_config)
81
81
  self.config = config
@@ -108,7 +108,7 @@ class TransformerBlock(nn.Module):
108
108
  x_norm = self.pre_atten_norm(x)
109
109
  attn_out = self.atten_func(x_norm, rope, mask, input_pos)
110
110
  x = x + attn_out
111
- x_norm = self.pre_ff_norm(x)
111
+ x_norm = self.post_atten_norm(x)
112
112
  output = x + self.ff(x_norm)
113
113
 
114
114
  return output
@@ -228,8 +228,15 @@ class CausalSelfAttention(nn.Module):
228
228
  # TODO(haoliang): Handle when execeeding max sequence length.
229
229
  k, v = self.kv_cache.update_cache(input_pos, k, v)
230
230
 
231
- y = self.sdpa_func(q, k, v, self.config.head_dim, mask=mask)
232
- y = y.reshape(B, T, E)
231
+ y = self.sdpa_func(
232
+ q,
233
+ k,
234
+ v,
235
+ self.config.head_dim,
236
+ mask=mask,
237
+ softcap=self.config.logit_softcap,
238
+ )
239
+ y = y.reshape(B, T, -1)
233
240
 
234
241
  # Compute the output projection.
235
242
  y = self.output_projection(y)
ai_edge_torch/generative/layers/attention_utils.py CHANGED
@@ -74,12 +74,42 @@ def build_causal_mask_cache(
  Returns:
  torch.Tensor: Causal attention mask.
  """
+
  if device is None:
  device = torch.device('cpu')
  mask = torch.full((size, size), float('-inf'), dtype=dtype, device=device)
  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)


+ def build_sliding_window_mask_cache(
+ size: int,
+ window_size: int,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = None,
+ ) -> torch.Tensor:
+ """Build a cache for a sliding window mask.
+
+ Args:
+ size (int): The size of the built mask cache.
+ window_size (int): The window size that is "seen" by a token.
+ dtype (torch.dtype, optional): Output tensor's data type. Defaults to
+ torch.float32.
+ device (torch.device, optional): Output tensor's device. Defaults to
+ None, in which case "cpu" is used.
+
+ Returns:
+ torch.Tensor: Sliding window attention mask.
+ """
+
+ mask = build_causal_mask_cache(size, dtype, device)
+ all_ones = torch.ones_like(mask)
+ window_size = min(size, window_size)
+ sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
+ all_ones, window_size - 1
+ )
+ return torch.where(sliding_mask == 1, mask, -2.3819763e38)
+
+
  def relative_position_bucket(
  relative_position: torch.Tensor,
  bidirectional: bool,
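A quick sketch of what the new mask cache looks like (assuming the function as added above; the sizes are illustrative):

import torch
import ai_edge_torch.generative.layers.attention_utils as attn_utils

# A 4-position cache with a 2-token window. The result has shape
# (1, 1, size, size); entry [..., i, j] is 0.0 where token i may attend to
# token j (j <= i and i - j < window_size) and a large negative value
# everywhere else, so it can be added directly to attention scores.
mask = attn_utils.build_sliding_window_mask_cache(size=4, window_size=2)
print(mask[0, 0])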
@@ -89,11 +89,16 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
89
89
 
90
90
  activation = get_activation(config.activation)
91
91
 
92
+ pre_ff_norm = build_norm(dim, config.pre_ff_norm_config)
93
+ post_ff_norm = build_norm(dim, config.post_ff_norm_config)
94
+
92
95
  return ff_module(
93
96
  dim=dim,
94
97
  hidden_dim=config.intermediate_size,
95
98
  activation=activation,
96
99
  use_bias=config.use_bias,
100
+ pre_ff_norm=pre_ff_norm,
101
+ post_ff_norm=post_ff_norm,
97
102
  )
98
103
 
99
104
 
@@ -14,7 +14,7 @@
14
14
  # ==============================================================================
15
15
  # Common building blocks for FeedForward layers.
16
16
 
17
- from typing import Callable
17
+ from typing import Callable, Optional
18
18
 
19
19
  import torch
20
20
  from torch import nn
@@ -30,6 +30,8 @@ class SequentialFeedForward(nn.Module):
30
30
  hidden_dim: int,
31
31
  activation: Callable[[torch.Tensor], torch.Tensor],
32
32
  use_bias=False,
33
+ pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
34
+ post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
33
35
  ):
34
36
  """Init function for feedforward layer.
35
37
 
@@ -41,6 +43,8 @@ class SequentialFeedForward(nn.Module):
41
43
  self.act = activation
42
44
  self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
43
45
  self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
46
+ self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
47
+ self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
44
48
 
45
49
  def forward(self, x):
46
50
  """Forward pass for Feedforward layer.
@@ -51,7 +55,9 @@ class SequentialFeedForward(nn.Module):
51
55
  Returns:
52
56
  torch.Tensor: output tensor after feedforward.
53
57
  """
54
- return self.w2(self.act(self.w1(x)))
58
+ x_norm = self.pre_ff_norm(x)
59
+ out = self.w2(self.act(self.w1(x_norm)))
60
+ return self.post_ff_norm(out)
55
61
 
56
62
 
57
63
  class GatedFeedForward(nn.Module):
@@ -66,6 +72,8 @@ class GatedFeedForward(nn.Module):
66
72
  hidden_dim: int,
67
73
  activation: Callable[[torch.Tensor], torch.Tensor],
68
74
  use_bias=False,
75
+ pre_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
76
+ post_ff_norm: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
69
77
  ):
70
78
  """Init function for feedforward layer.
71
79
 
@@ -78,6 +86,8 @@ class GatedFeedForward(nn.Module):
78
86
  self.w1 = nn.Linear(dim, hidden_dim, bias=use_bias)
79
87
  self.w2 = nn.Linear(hidden_dim, dim, bias=use_bias)
80
88
  self.w3 = nn.Linear(dim, hidden_dim, bias=use_bias)
89
+ self.pre_ff_norm = pre_ff_norm if pre_ff_norm else lambda x: x
90
+ self.post_ff_norm = post_ff_norm if post_ff_norm else lambda x: x
81
91
 
82
92
  def forward(self, x):
83
93
  """Forward pass for Feedforward layer.
@@ -88,4 +98,6 @@ class GatedFeedForward(nn.Module):
88
98
  Returns:
89
99
  torch.Tensor: output tensor after feedforward.
90
100
  """
91
- return self.w2(self.act(self.w1(x)) * self.w3(x))
101
+ x_norm = self.pre_ff_norm(x)
102
+ out = self.w2(self.act(self.w1(x_norm)) * self.w3(x_norm))
103
+ return self.post_ff_norm(out)
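A usage sketch of the new wiring (the RMS-norm and size choices below are illustrative, not taken from a particular model): the per-feed-forward norms now live on FeedForwardConfig and are applied inside the FF module instead of by the transformer block.

import torch
import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

# Gated FF whose input and output are RMS-normalized by the module itself.
norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
ff_config = cfg.FeedForwardConfig(
    type=cfg.FeedForwardType.GATED,
    activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
    intermediate_size=256,
    pre_ff_norm_config=norm_config,
    post_ff_norm_config=norm_config,
)
ff = builder.build_ff(64, ff_config)  # embedding_dim=64
y = ff(torch.randn(1, 8, 64))         # pre-norm -> gated FF -> post-norm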
@@ -16,7 +16,7 @@
16
16
  from dataclasses import dataclass
17
17
  from dataclasses import field
18
18
  import enum
19
- from typing import Optional
19
+ from typing import Optional, Sequence
20
20
 
21
21
 
22
22
  @enum.unique
@@ -53,6 +53,11 @@ class FeedForwardType(enum.Enum):
53
53
  GATED = enum.auto()
54
54
 
55
55
 
56
+ class AttentionType(enum.Enum):
57
+ GLOBAL = enum.auto()
58
+ LOCAL_SLIDING = enum.auto()
59
+
60
+
56
61
  @dataclass
57
62
  class AttentionConfig:
58
63
  """Attention model's parameters."""
@@ -78,6 +83,12 @@ class AttentionConfig:
78
83
  enable_kv_cache: bool = True
79
84
  relative_attention_num_buckets: int = 0
80
85
  relative_attention_max_distance: int = 0
86
+ # Softcap on the output logits.
87
+ logit_softcap: Optional[float] = None
88
+ # The types of attention used in the layers of the model.
89
+ attn_types: Optional[Sequence[AttentionType]] = None
90
+ # The size of the sliding window used for local attention.
91
+ sliding_window_size: Optional[int] = None
81
92
 
82
93
 
83
94
  @dataclass
@@ -88,16 +99,6 @@ class ActivationConfig:
88
99
  dim_out: Optional[int] = None
89
100
 
90
101
 
91
- @dataclass
92
- class FeedForwardConfig:
93
- """FeedForward module's parameters."""
94
-
95
- type: FeedForwardType
96
- activation: ActivationConfig
97
- intermediate_size: int
98
- use_bias: bool = False
99
-
100
-
101
102
  @dataclass
102
103
  class NormalizationConfig:
103
104
  """Normalizater parameters."""
@@ -109,6 +110,24 @@ class NormalizationConfig:
109
110
  group_num: Optional[float] = None
110
111
 
111
112
 
113
+ @dataclass
114
+ class FeedForwardConfig:
115
+ """FeedForward module's parameters."""
116
+
117
+ type: FeedForwardType
118
+ activation: ActivationConfig
119
+ intermediate_size: int
120
+ use_bias: bool = False
121
+ # The normalization applied to feed forward's input.
122
+ pre_ff_norm_config: NormalizationConfig = field(
123
+ default_factory=NormalizationConfig
124
+ )
125
+ # The normalization applied to feed forward's output.
126
+ post_ff_norm_config: NormalizationConfig = field(
127
+ default_factory=NormalizationConfig
128
+ )
129
+
130
+
112
131
  @dataclass
113
132
  class ModelConfig:
114
133
  """Base configurations for building a transformer architecture."""
@@ -124,8 +143,8 @@ class ModelConfig:
124
143
  pre_attention_norm_config: NormalizationConfig = field(
125
144
  default_factory=NormalizationConfig
126
145
  )
127
- # The normalization applied to feed forward's input.
128
- pre_ff_norm_config: NormalizationConfig = field(
146
+ # The normalization applied to attention's output.
147
+ post_attention_norm_config: NormalizationConfig = field(
129
148
  default_factory=NormalizationConfig
130
149
  )
131
150
  # The normalization applied before LM head.
@@ -151,6 +170,9 @@ class ModelConfig:
151
170
  # Default batch size of the exported model. Default value is 1.
152
171
  batch_size: int = 1
153
172
 
173
+ # Softcap on the model output logits.
174
+ final_logit_softcap: Optional[float] = None
175
+
154
176
  @property
155
177
  def kv_cache_max(self) -> int:
156
178
  if self.kv_cache_max_len > 0:
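A sketch of how the new AttentionConfig fields fit together for a Gemma-2-style stack (values are illustrative only): attn_types picks a global or sliding-window mask per layer, sliding_window_size bounds the local window, and logit_softcap tanh-caps the attention scores.

import ai_edge_torch.generative.layers.model_config as cfg

# Illustrative: four layers alternating global and local sliding-window
# attention, a 1024-token window, and attention-logit soft-capping at 50.
attn_config = cfg.AttentionConfig(
    num_heads=8,
    head_dim=64,
    num_query_groups=4,
    rotary_percentage=1.0,
    logit_softcap=50.0,
    sliding_window_size=1024,
    attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING] * 2,
)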
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -29,6 +29,7 @@ def scaled_dot_product_attention(
  head_size: int,
  mask: Optional[torch.Tensor] = None,
  scale: Optional[float] = None,
+ softcap: Optional[float] = None,
  ):
  """Scaled dot product attention.

@@ -53,15 +54,26 @@ def scaled_dot_product_attention(
  # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
  k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
  v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
- y = F.scaled_dot_product_attention(
- q,
- k,
- v,
- attn_mask=mask,
- dropout_p=0.0,
- is_causal=mask is None,
- scale=scale,
- )
+ if softcap is None:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=0.0,
+ is_causal=mask is None,
+ scale=scale,
+ )
+ else:
+ q.mul_(scale)
+ scores = q @ k.transpose(-1, -2)
+ scores = scores / softcap
+ scores = torch.tanh(scores)
+ scores = scores * softcap
+ scores = scores + mask
+ out = F.softmax(scores.float(), dim=-1).type_as(q)
+ y = torch.matmul(out, v)
+
  return y.transpose(1, 2)


@@ -72,6 +84,7 @@ def scaled_dot_product_attention_with_hlfb(
  head_size: int,
  mask: Optional[torch.Tensor] = None,
  scale: Optional[float] = None,
+ softcap: Optional[float] = None,
  ):
  """Scaled dot product attention with high-level function boundary enabled.

@@ -86,6 +99,9 @@ def scaled_dot_product_attention_with_hlfb(
  The output tensor of scaled_dot_product_attention.
  """

+ if softcap is not None:
+ raise NotImplementedError("SDPA with HLFB not available with softcap.")
+
  if scale is None:
  scale = 1.0 / math.sqrt(head_size)

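The soft-cap branch above is just tanh applied to the scaled scores before the softmax. A standalone sketch of the same math in plain PyTorch (not the package's function; shapes follow the multi-head layout used above):

import math
import torch
import torch.nn.functional as F

def softcapped_attention(q, k, v, mask, softcap, scale=None):
  # q, k, v: (B, num_heads, T, head_dim); mask is additive and broadcastable
  # to the (B, num_heads, T, T) score matrix.
  if scale is None:
    scale = 1.0 / math.sqrt(q.shape[-1])
  scores = (q * scale) @ k.transpose(-1, -2)
  scores = torch.tanh(scores / softcap) * softcap  # keep logits in (-softcap, softcap)
  scores = scores + mask
  probs = F.softmax(scores.float(), dim=-1).type_as(q)
  return probs @ v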
@@ -16,7 +16,7 @@
16
16
  import copy
17
17
 
18
18
  import ai_edge_torch
19
- from ai_edge_torch.generative.examples.gemma import gemma
19
+ from ai_edge_torch.generative.examples.gemma import gemma, gemma2
20
20
  from ai_edge_torch.generative.examples.phi2 import phi2
21
21
  from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
22
22
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
@@ -202,6 +202,34 @@ class TestModelConversion(googletest.TestCase):
202
202
  )
203
203
  )
204
204
 
205
+ def test_gemma2(self):
206
+ self.skipTest("b/338288901")
207
+ config = gemma2.get_fake_model_config_2b_for_test()
208
+ model = gemma2.Gemma2(config)
209
+ model.eval()
210
+
211
+ idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
212
+ tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
213
+ tokens[0, :4] = idx
214
+ input_pos = torch.arange(0, 10)
215
+
216
+ edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
217
+
218
+ # TODO: b/338288901 - re-enable test to check output tensors.
219
+ skip_output_check = True
220
+ if not skip_output_check:
221
+ # TODO(talumbau, haoliang): debug numerical diff.
222
+ self.assertTrue(
223
+ model_coverage.compare_tflite_torch(
224
+ edge_model,
225
+ model,
226
+ (tokens, input_pos),
227
+ num_valid_inputs=1,
228
+ atol=1e-2,
229
+ rtol=1e-5,
230
+ )
231
+ )
232
+
205
233
  def test_phi2(self):
206
234
  self.skipTest("b/338288901")
207
235
  config = phi2.get_fake_model_config_for_test()
@@ -107,7 +107,9 @@ class ModelLoader:
107
107
  ff_gate_proj: str = None
108
108
 
109
109
  pre_attn_norm: str = None
110
+ post_attn_norm: str = None
110
111
  pre_ff_norm: str = None
112
+ post_ff_norm: str = None
111
113
  embedding: str = None
112
114
  embedding_position: str = None
113
115
  final_norm: str = None
@@ -258,6 +260,26 @@ class ModelLoader:
258
260
  f"{ff_gate_proj_name}.bias"
259
261
  )
260
262
 
263
+ if self._names.pre_ff_norm is not None:
264
+ pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
265
+ converted_state[f"{prefix}.ff.pre_ff_norm.weight"] = state.pop(
266
+ f"{pre_ff_norm_name}.weight"
267
+ )
268
+ if f"{pre_ff_norm_name}.bias" in state:
269
+ converted_state[f"{prefix}.ff.pre_ff_norm.bias"] = state.pop(
270
+ f"{pre_ff_norm_name}.bias"
271
+ )
272
+
273
+ if self._names.post_ff_norm is not None:
274
+ post_ff_norm_name = self._names.post_ff_norm.format(idx)
275
+ converted_state[f"{prefix}.ff.post_ff_norm.weight"] = state.pop(
276
+ f"{post_ff_norm_name}.weight"
277
+ )
278
+ if f"{post_ff_norm_name}.bias" in state:
279
+ converted_state[f"{prefix}.ff.post_ff_norm.bias"] = state.pop(
280
+ f"{post_ff_norm_name}.bias"
281
+ )
282
+
261
283
  def _map_attention(
262
284
  self,
263
285
  idx: int,
@@ -325,14 +347,14 @@ class ModelLoader:
325
347
  f"{pre_attn_norm_name}.bias"
326
348
  )
327
349
 
328
- if self._names.pre_ff_norm is not None:
329
- pre_ff_norm_name = self._names.pre_ff_norm.format(idx)
330
- converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
331
- f"{pre_ff_norm_name}.weight"
350
+ if self._names.post_attn_norm is not None:
351
+ post_attn_norm_name = self._names.post_attn_norm.format(idx)
352
+ converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
353
+ f"{post_attn_norm_name}.weight"
332
354
  )
333
- if f"{pre_ff_norm_name}.bias" in state:
334
- converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
335
- f"{pre_ff_norm_name}.bias"
355
+ if f"{post_attn_norm_name}.bias" in state:
356
+ converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
357
+ f"{post_attn_norm_name}.bias"
336
358
  )
337
359
 
338
360
  def _fuse_qkv(
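For orientation, a sketch of the renamed mapping in use (the checkpoint key follows the Gemma layout used elsewhere in this diff; the destination key is illustrative of how the loader prefixes block weights):

import ai_edge_torch.generative.utilities.loader as loading_utils

# post_attn_norm names the checkpoint's post-attention layernorm; the loader
# copies it into the transformer block's post_atten_norm module, e.g.
#   "model.layers.0.post_attention_layernorm.weight"
#     -> "transformer_blocks.0.post_atten_norm.weight"
names = loading_utils.ModelLoader.TensorNames(
    post_attn_norm="model.layers.{}.post_attention_layernorm",
)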
@@ -113,7 +113,7 @@ class ModelLoader:
113
113
 
114
114
  pre_attn_norm: str = None
115
115
  pre_cross_attn_norm: str = None
116
- pre_ff_norm: str = None
116
+ post_attn_norm: str = None
117
117
  embedding: str = None
118
118
  final_norm: str = None
119
119
  lm_head: str = None
@@ -484,14 +484,14 @@ class ModelLoader:
484
484
  state.pop(f"{pre_cross_attn_norm_name}.bias")
485
485
  )
486
486
 
487
- if names.pre_ff_norm is not None:
488
- pre_ff_norm_name = names.pre_ff_norm.format(idx)
489
- converted_state[f"{prefix}.pre_ff_norm.weight"] = state.pop(
490
- f"{pre_ff_norm_name}.weight"
487
+ if names.post_attn_norm is not None:
488
+ post_attn_norm_name = names.post_attn_norm.format(idx)
489
+ converted_state[f"{prefix}.post_atten_norm.weight"] = state.pop(
490
+ f"{post_attn_norm_name}.weight"
491
491
  )
492
- if f"{pre_ff_norm_name}.bias" in state:
493
- converted_state[f"{prefix}.pre_ff_norm.bias"] = state.pop(
494
- f"{pre_ff_norm_name}.bias"
492
+ if f"{post_attn_norm_name}.bias" in state:
493
+ converted_state[f"{prefix}.post_atten_norm.bias"] = state.pop(
494
+ f"{post_attn_norm_name}.bias"
495
495
  )
496
496
 
497
497
  def _fuse_qkv(
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240813"
16
+ __version__ = "0.3.0.dev20240815"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240813
3
+ Version: 0.3.0.dev20240815
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=5DYNpFVwvI1w0JbAC1hn83NJVGS1WPX7n742419PMqs,4558
5
- ai_edge_torch/version.py,sha256=C9Lsgh_kXnELi8xLPUgnmDTLFOKW5S5z6lXAVwLMypU,706
5
+ ai_edge_torch/version.py,sha256=Ac80tDmZqQ9SHOuHGuT0szYDNOLnaak-9rePRq6P7eE,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,22 +42,24 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
42
42
  ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
43
  ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
44
44
  ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=QoFbUUCTJrW1IYZg0vfb2-K-X0q1-NJFbWNGPQGwBgk,6688
45
+ ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=8313wSsddvuxZ5ZYVdaITBV2FF1k22dcCujnq0UZvKs,6699
46
46
  ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
47
  ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
48
48
  ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231
49
49
  ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
50
  ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
51
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=GOLLd9yCBnlNXeW7xrVy1wjOltcTbRdSpiJycbMj8TA,6372
51
+ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=zQYtyk3xYdiRAnzMKN58Q_wgTQFnDujxp6L4RFQjiD4,6383
52
52
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
53
54
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
54
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=5Dn9JgJiXN-hWGQj9YqCr8Iik8mh5s0dX0VfyY8KDDo,6236
55
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
56
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=j-zxJ-JNRnQ_kDzUESmsyy_a_4IxWZ510HmIImc0LDc,8240
55
57
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
56
58
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
57
59
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
58
60
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
59
61
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
60
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=VR09iAnj1e-sr-oam2rh24Wnb_JdZZQvpJIjylfgnS8,4468
62
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
61
63
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7ra36nM5tQwSw-vi6QCFLx5IssZhT-6yVK4H3XsAc4w,5044
62
64
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
63
65
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
@@ -72,27 +74,27 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
72
74
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
73
75
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
74
76
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz70WOvNpTJ9LFkiDnlwgJiXfUZCVk,4548
75
- ai_edge_torch/generative/examples/t5/t5.py,sha256=6Rkisv7UI2w5KV8ogPPzeIiPWYwDLfFfSIncqD7Eenc,20854
76
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=gp7DV8pv4FwICQhYlUYfYZ7BE5jzDIsD_V3a_4-T4Ds,8492
77
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
78
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
77
79
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
78
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=DhxOrIKe-tilBjbh1q4MsmCmmKMc4c1BPUzhnaJDD6M,3955
79
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=bW0QB-_h9cfwAQf11AxFxOBq3HrEep_UlpBjXz3JSew,5801
80
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=CRja_CT0_eaH16rSDxwHKJS_CGUJMW0Fxd4r45Ii8Uo,4833
80
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
81
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
82
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
81
83
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
82
84
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
83
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=nu3Il8Vxe7JwM8-AnGNXoGoZ9eVXKHMYEAqVEP-gwe8,5929
85
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mXXFYJfo8yegSOFOndCR0oYxFPchYb9vTJ4ThXGIFLU,5940
84
86
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
85
87
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
86
88
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
87
- ai_edge_torch/generative/layers/attention.py,sha256=xq10Gw4GudK4M2eY8-H4fi3qmpmZCfE-CziAXDZvqiQ,12177
88
- ai_edge_torch/generative/layers/attention_utils.py,sha256=2hzBVZvWCqqLfI-f3RJA1hi6T8cuaIJBPt8cdjQCA5s,6420
89
- ai_edge_torch/generative/layers/builder.py,sha256=JvPmwrG8_M4-kO2MM6sDZhpS32Wx3wVVhlVO4yPJKJ0,4161
90
- ai_edge_torch/generative/layers/feed_forward.py,sha256=RukSYr9h_DehcYVZWLS_rfCTY73Uj__pTRUatjxJtv8,2788
89
+ ai_edge_torch/generative/layers/attention.py,sha256=2UujQePRJ1LK02PN-hGcuMu0ooCJC6ETfPvzEYVFyho,12284
90
+ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
91
+ ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
92
+ ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
91
93
  ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lfvKca8JbWtFx2CE,3082
92
- ai_edge_torch/generative/layers/model_config.py,sha256=CTvKFwsBR3Rc-Kf73NA7k0799m1WnEvaEBKCnnfNkyo,4961
94
+ ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
93
95
  ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
94
96
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
95
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=6WMe-A5KSSujQcZ34hIeSnnor3AXrw10cQ5FKy-30IU,3390
97
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=x2bOmrTgOISXcb06IDP7X3xgftpPpxOjBXw_OxTMVns,3874
96
98
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
97
99
  ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
98
100
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
@@ -109,12 +111,12 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha
109
111
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
112
  ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=T5-O2RVLJTH7v9w1_uBfp-Y7o3sdGzYq2Tj2wLRNHyI,4357
111
113
  ai_edge_torch/generative/test/test_loader.py,sha256=1ZqAq0HY5uIioumsReOVIsbGBx0WkYcl18PvttdJKrk,3381
112
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=4RTB1oPA2eWPyuof2-ZB1BxVKzKy5Q9vCux7psmV6zc,7615
114
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=52ciFy_Qol2Xuym6P6EqdL29oai35LSWGvsUwyEdFTo,8477
113
115
  ai_edge_torch/generative/test/test_quantize.py,sha256=3SmJm7Kq98gAneU6IGwwJrJYCVH1qwWR6oUxPfb6qiI,5346
114
116
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
115
- ai_edge_torch/generative/utilities/loader.py,sha256=XfVRvwvZyQuofctxIedLNDKQrsy9UlRr4wpScZJLWcw,11779
117
+ ai_edge_torch/generative/utilities/loader.py,sha256=bAWZ7FM4v_pPnX_AmEdGxHkDH65QdL-MjIP3PxscZmI,12649
116
118
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
117
- ai_edge_torch/generative/utilities/t5_loader.py,sha256=jz2qnDtH6oyxcqaBwEVfiiKmq_93LTDeUKNJ2cWpLwg,16856
119
+ ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
118
120
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
119
121
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
120
122
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -134,8 +136,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
134
136
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
135
137
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
136
138
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
137
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
138
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/METADATA,sha256=TMkI635DYqK0Fg6W6tZbg8ZTT54_9QkkCcd3XOxjyho,1885
139
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
140
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
141
- ai_edge_torch_nightly-0.3.0.dev20240813.dist-info/RECORD,,
139
+ ai_edge_torch_nightly-0.3.0.dev20240815.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
140
+ ai_edge_torch_nightly-0.3.0.dev20240815.dist-info/METADATA,sha256=aK9kjCcC_P6dkcvL_vcrXJLg7sn3Pfb2jlpSMrrKJ6Q,1885
141
+ ai_edge_torch_nightly-0.3.0.dev20240815.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
142
+ ai_edge_torch_nightly-0.3.0.dev20240815.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
143
+ ai_edge_torch_nightly-0.3.0.dev20240815.dist-info/RECORD,,