ai-edge-torch-nightly: 0.3.0.dev20240918-py3-none-any.whl → 0.3.0.dev20240919-py3-none-any.whl

Files changed (23)
  1. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  2. ai_edge_torch/generative/examples/openelm/verify.py +61 -0
  3. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  4. ai_edge_torch/generative/examples/phi/verify.py +53 -0
  5. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  6. ai_edge_torch/generative/examples/smollm/verify.py +59 -0
  7. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  8. ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
  9. ai_edge_torch/generative/layers/builder.py +3 -1
  10. ai_edge_torch/generative/layers/model_config.py +3 -0
  11. ai_edge_torch/generative/layers/normalization.py +31 -20
  12. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  13. ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
  14. ai_edge_torch/generative/layers/unet/model_config.py +1 -0
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
  16. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
  17. ai_edge_torch/generative/utilities/verifier.py +200 -0
  18. ai_edge_torch/version.py +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +23 -18
  21. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/openelm/verify.py ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
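Since the prompts are defined with absl's DEFINE_multi_string, these verify scripts can presumably be run from a source checkout along the following lines (the path and flag values are illustrative; passing --prompts repeatedly appends each value to the list read via _PROMPTS.value):

    python ai_edge_torch/generative/examples/openelm/verify.py \
        --prompts="What is the meaning of life?" \
        --prompts="Explain the OpenELM architecture."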
ai_edge_torch/generative/examples/phi/phi2.py CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/phi/verify.py ADDED
@@ -0,0 +1,53 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""
 
 import copy
-import os
-import pathlib
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn
 
 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/smollm/verify.py ADDED
@@ -0,0 +1,59 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building a TinyLlama model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/tiny_llama/verify.py ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored TinyLlama-1.1B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Show me the program to add 2 and 3.",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -69,6 +69,9 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use the input shape to determine the dimension of normalization
+  # when type is LAYER_NORM.
+  use_input_shape: bool = True
 
 
 @dataclass
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
       group_num (int): Number of groups to separate the channels into.
       dim (int): Dimension of the input tensor.
      eps (float): A small float value to ensure numerical stability (default:
-        1e-6).
+        1e-5).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.
    """
@@ -112,7 +112,13 @@
 
 class LayerNorm(torch.nn.Module):
 
-  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+      use_input_shape: bool = True,
+  ):
     """Initialize the LayerNorm layer.
 
     Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
         1e-6).
       enable_hlfb (bool): Whether to convert this normalization into a single
         op.
+      use_input_shape (bool): Whether to use the input shape to determine the
+        dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
+    self.use_input_shape = use_input_shape
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.ones(dim))
     self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.eps,
+          x, self.weight, self.bias, self.eps, self.use_input_shape
       )
+
+    if self.use_input_shape:
+      normalized_shape = x.shape
+      weight = self.weight.broadcast_to(x.shape)
+      bias = self.bias.broadcast_to(x.shape)
     else:
-      return F.layer_norm(
-          x,
-          x.shape,
-          self.weight.broadcast_to(x.shape),
-          self.bias.broadcast_to(x.shape),
-          self.eps,
-      )
+      normalized_shape = self.weight.shape
+      weight = self.weight
+      bias = self.bias
+    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
 
 
 def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
+    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.
 
@@ -201,18 +210,20 @@
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization.
 
   Returns:
     The output tensor of Layer Normalization.
   """
   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
   x, w, b = builder.mark_inputs(x, w, b)
-  y = F.layer_norm(
-      x,
-      x.shape,
-      weight=w.broadcast_to(x.shape),
-      bias=b.broadcast_to(x.shape),
-      eps=eps,
-  )
+  if use_input_shape:
+    normalized_shape = x.shape
+    w = w.broadcast_to(x.shape)
+    b = b.broadcast_to(x.shape)
+  else:
+    normalized_shape = w.shape
+  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
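The use_input_shape flag chooses which normalized_shape F.layer_norm receives. A minimal standalone sketch (plain PyTorch, independent of this package) of the two branches; the use_input_shape=False case matches torch.nn.LayerNorm's usual behavior of normalizing over the weight's trailing dimension, which is what the Phi-2 config above requests:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8)  # (batch, seq, dim)
    weight = torch.ones(8)
    bias = torch.zeros(8)

    # use_input_shape=True: normalize over every dimension of the input,
    # broadcasting the per-channel weight and bias to the full input shape.
    y_input = F.layer_norm(
        x, x.shape, weight.broadcast_to(x.shape), bias.broadcast_to(x.shape), 1e-5
    )

    # use_input_shape=False: normalize over the weight's shape (the last dim),
    # equivalent to torch.nn.LayerNorm(8) with the same weight and bias.
    y_weight = F.layer_norm(x, weight.shape, weight, bias, 1e-5)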
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
     # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
     k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
     v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
 
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
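The new else branch implements logit soft-capping: raw attention scores are squashed into (-softcap, softcap) via softcap * tanh(scores / softcap) before the mask and softmax are applied, so no single logit can dominate. A minimal standalone sketch of that path (shapes assumed: q, k, v of shape (B, H, T, D) and an additive mask broadcastable to the (B, H, T, T) scores):

    import torch
    import torch.nn.functional as F

    def softcapped_attention(q, k, v, mask, scale, softcap):
      q = q * scale
      scores = q @ k.transpose(-1, -2)
      # Squash the logits into (-softcap, softcap) before masking and softmax.
      scores = softcap * torch.tanh(scores / softcap)
      scores = scores + mask
      probs = F.softmax(scores.float(), dim=-1).type_as(q)
      return probs @ v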
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
     )
     self.conv_1 = nn.Conv2d(
         config.in_channels,
-        config.out_channels,
+        config.hidden_channels,
         kernel_size=3,
         stride=1,
         padding=1,
     )
     if config.time_embedding_channels is not None:
       self.time_emb_proj = nn.Linear(
-          config.time_embedding_channels, config.out_channels
+          config.time_embedding_channels, config.hidden_channels
       )
     else:
       self.time_emb_proj = None
     self.norm_2 = layers_builder.build_norm(
-        config.out_channels, config.normalization_config
+        config.hidden_channels, config.normalization_config
     )
     self.conv_2 = nn.Conv2d(
-        config.out_channels,
+        config.hidden_channels,
         config.out_channels,
         kernel_size=3,
         stride=1,
@@ -391,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
+                  hidden_channels=config.out_channels,
                   out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
@@ -492,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
+                  hidden_channels=config.out_channels,
                   out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
@@ -602,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=resnet_in_channels + res_skip_channels,
+                  hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -706,6 +709,7 @@ class MidBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
+                 hidden_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -724,6 +728,7 @@ class MidBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
+                 hidden_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
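The new hidden_channels field decouples the width of the block's inner activations from its output width: conv_1 now maps in_channels to hidden_channels and conv_2 maps hidden_channels to out_channels (every call site above still passes an inner width equal to the block's input or output width). A generic sketch of that channel flow in plain PyTorch, not this package's actual class:

    import torch
    from torch import nn

    class TinyResidualBlock(nn.Module):
      """Illustrative only: in -> hidden -> out, plus a matching skip path."""

      def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv_1 = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)
        self.conv_2 = nn.Conv2d(hidden_channels, out_channels, 3, 1, 1)
        self.skip = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, 1)
        )

      def forward(self, x):
        return self.conv_2(torch.relu(self.conv_1(x))) + self.skip(x)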
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -48,6 +48,7 @@ class DownSamplingConfig:
 @dataclasses.dataclass
 class ResidualBlock2DConfig:
   in_channels: int
+  hidden_channels: int
   out_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
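For reference, torch.allclose combines the two tolerances elementwise as |a - b| <= atol + rtol * |b|, so the tightened values above demand agreement roughly a thousand times closer than before. A small illustration:

    import torch

    a = torch.tensor([1.0, 2.0])
    b = torch.tensor([1.00005, 2.00005])
    # Elementwise check: |a - b| <= atol + rtol * |b|.
    print(torch.allclose(a, b, atol=1e-4, rtol=1e-5))  # True
    print(torch.allclose(a, b, atol=1e-5, rtol=1e-5))  # False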
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
   ):
     residual_block_config = unet_config.ResidualBlock2DConfig(
         in_channels=config.in_channels,
+        hidden_channels=config.in_channels,
        out_channels=config.in_channels,
        time_embedding_channels=config.time_embedding_channels,
        normalization_config=config.normalization_config,
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
        f"{converted_state_param_prefix}.resnets.{i}",
        unet_config.ResidualBlock2DConfig(
            in_channels=input_channels,
+            hidden_channels=config.out_channels,
            out_channels=config.out_channels,
            time_embedding_channels=config.time_embedding_channels,
            normalization_config=config.normalization_config,
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
        f"{converted_state_param_prefix}.resnets.{i}",
        unet_config.ResidualBlock2DConfig(
            in_channels=input_channels,
+            hidden_channels=config.out_channels,
            out_channels=config.out_channels,
            time_embedding_channels=config.time_embedding_channels,
            normalization_config=config.normalization_config,
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
        f"{converted_state_param_prefix}.resnets.{i}",
        unet_config.ResidualBlock2DConfig(
            in_channels=resnet_in_channels + res_skip_channels,
+            hidden_channels=config.out_channels,
            out_channels=config.out_channels,
            time_embedding_channels=config.time_embedding_channels,
            normalization_config=config.normalization_config,
ai_edge_torch/generative/utilities/verifier.py ADDED
@@ -0,0 +1,200 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utility functions to verify the reauthored models."""
+
+import datetime
+from typing import List
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import numpy as np
+import torch
+
+
+def log_msg(*args):
+  print("[%s]" % datetime.datetime.now(), *args)
+
+
+def forward(
+    model: torch.nn.Module,
+    tokens: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+) -> tuple[torch.Tensor, kv_utils.KVCache]:
+  """Forwards the model reauthored with ai_edge_torch Generative API.
+
+  Args:
+    model (torch.nn.Module): The model to forward. It should be a model built
+      with ai_edge_torch Generative API.
+    tokens (torch.Tensor): The input tokens to forward.
+    kv_cache (KVCache): The KV cache to forward.
+
+  Returns:
+    The output logits and the updated KV cache.
+  """
+  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+  output = model.forward(tokens, input_pos, kv_cache)
+  return output["logits"], output["kv_cache"]
+
+
+def generate(
+    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
+) -> torch.Tensor:
+  """Generates the response to the prompts.
+
+  It appends tokens output by the model to the prompts and feeds them back to
+  the model up to response_len.
+
+  Args:
+    model (torch.nn.Module): The model to generate with. It should be a model
+      built with ai_edge_torch Generative API.
+    prompts (torch.Tensor): The prompts to generate from.
+    response_len (int): The number of tokens to generate.
+
+  Returns:
+    The generated tokens.
+  """
+  input_ids = prompts[0].int().tolist()
+  kv_cache = kv_utils.KVCache.from_model_config(model.config)
+  for _ in range(response_len - len(input_ids)):
+    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
+    generated_token = logits[0][-1].argmax().item()
+    input_ids.append(generated_token)
+  return torch.tensor([input_ids])
+
+
+def verify_with_input_ids(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    kv_cache_max_len: int = 1024,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+) -> bool:
+  """Verifies if the reauthored model generates the same output as the original.
+
+  It compares only one output from the original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    input_ids (torch.Tensor): The input token IDs to forward.
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+
+  Returns:
+    True if the reauthored model generates the same output as the original.
+  """
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  input_ids_len = input_ids.shape[1]
+  tokens[0, :input_ids_len] = input_ids
+
+  log_msg("Forwarding the original model...")
+  outputs_original = original_model.forward(tokens)
+  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  log_msg("logits_original:", logits_original)
+
+  log_msg("Forwarding the reauthored model...")
+  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
+  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
+  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  log_msg("logits_reauthored:", logits_reauthored)
+
+  return torch.allclose(
+      logits_original, logits_reauthored, rtol=rtol, atol=atol
+  )
+
+
+def verify_model_with_prompts(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: str,
+) -> bool:
+  """Verifies if the reauthored model generates the same answer as the original.
+
+  It compares an answer, i.e. multiple continuous outputs generated by the
+  original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (str): The input prompts to generate answers.
+
+  Returns:
+    True if the reauthored model generates the same answer as the original.
+  """
+  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+
+  log_msg("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens)
+  response_original = tokenizer.decode(outputs_original[0])
+  log_msg("outputs from original model: [[", response_original, "]]")
+
+  log_msg("Generating answer with the reauthored model...")
+  generate_len = len(outputs_original[0])
+  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  response_reauthored = tokenizer.decode(outputs_reauthored[0])
+  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+
+  return response_original == response_reauthored
+
+
+def verify_reauthored_model(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: List[str],
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored model against the original model.
+
+  It verifies the reauthored model with two methods:
+  1. It compares the output of the original and the reauthored model with an
+     arbitrary input.
+  2. It compares the answer generated by the original and the reauthored model
+     with a prompt.
+
+  It prints out "PASS" or "FAILED" to the console.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (List[str]): List of the input prompts to generate answers.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+  """
+  log_msg("Verifying the reauthored model with an arbitrary input...")
+  if verify_with_input_ids(
+      original_model, reauthored_model, rtol=rtol, atol=atol
+  ):
+    log_msg("PASS")
+  else:
+    log_msg("FAILED")
+
+  for p in prompts:
+    log_msg("Verifying the reauthored model with prompts:", p)
+    if verify_model_with_prompts(
+        original_model, reauthored_model, tokenizer, p
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240918"
+__version__ = "0.3.0.dev20240919"
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240918
+Version: 0.3.0.dev20240919
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=jWg5qA8V0XqgFoqjk0SCsNWPRBeTmfrir9u0bucHYOU,706
+ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -45,13 +45,16 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
+ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -76,23 +79,24 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71Wvv
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
+ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
 ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=zV3pA7giuKPrQdH81dpZz8D6LfGD-1YHuXuhIlypKc0,6784
-ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
+ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
+ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -104,14 +108,15 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5z9dvspFKO74Y_qIJ_VK0MYUoPdRf82Y,4498
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
-ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
+ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -158,8 +163,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/METADATA,sha256=dMaIr8Iny84IfNGQGSrtlTGkYlH_mAMmgvGWm5-pkxM,1859
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,