ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23)
  1. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  2. ai_edge_torch/generative/examples/openelm/verify.py +61 -0
  3. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  4. ai_edge_torch/generative/examples/phi/verify.py +53 -0
  5. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  6. ai_edge_torch/generative/examples/smollm/verify.py +59 -0
  7. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  8. ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
  9. ai_edge_torch/generative/layers/builder.py +3 -1
  10. ai_edge_torch/generative/layers/model_config.py +3 -0
  11. ai_edge_torch/generative/layers/normalization.py +31 -20
  12. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  13. ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
  14. ai_edge_torch/generative/layers/unet/model_config.py +1 -0
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
  16. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
  17. ai_edge_torch/generative/utilities/verifier.py +200 -0
  18. ai_edge_torch/version.py +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +23 -18
  21. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/openelm.py CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)
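Note: the removed define_and_run golden-logit check is superseded by the new verify.py script added below, which compares against the original Hugging Face model instead of stored golden tensors. As a rough sketch of the equivalence (the /tmp path is hypothetical; verify_with_input_ids is the new utility defined later in this diff, and its default input [[1, 2, 3, 4]] matches the tokens the old check used):

    from ai_edge_torch.generative.examples.openelm import openelm
    from ai_edge_torch.generative.utilities import verifier
    import transformers

    # Build the reauthored model from a locally downloaded checkpoint.
    model = openelm.build_model("/tmp/openelm", kv_cache_max_len=1024)
    original = transformers.AutoModelForCausalLM.from_pretrained(
        "apple/OpenELM-3B", trust_remote_code=True
    )
    # Plays the role of the old torch.allclose() check against golden logits.
    assert verifier.verify_with_input_ids(original, model)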
ai_edge_torch/generative/examples/openelm/verify.py ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
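One detail worth calling out: build_model needs a local directory, so the script reuses the snapshot that from_pretrained just downloaded by asking transformers.utils.cached_file where config.json landed and taking its parent directory. A standalone sketch of that lookup (the cache path in the comment is illustrative):

    import pathlib
    import transformers

    checkpoint = "apple/OpenELM-3B"
    # Returns the local path of the cached config file, e.g.
    # ~/.cache/huggingface/hub/models--apple--OpenELM-3B/snapshots/<rev>/config.json
    config_path = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
    )
    local_dir = pathlib.Path(config_path).parent  # snapshot dir with the weights

The tokenizer is loaded from meta-llama/Llama-2-7b-hf because OpenELM ships without its own tokenizer and reuses Llama 2's.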
ai_edge_torch/generative/examples/phi/phi2.py CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/phi/verify.py ADDED
@@ -0,0 +1,53 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/smollm/smollm.py CHANGED
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""
 
 import copy
-import os
-import pathlib
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn
 
 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/smollm/verify.py ADDED
@@ -0,0 +1,59 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -15,16 +15,12 @@
 
 """Example of building a TinyLlama model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/tiny_llama/verify.py ADDED
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored TinyLlama-1.1B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Show me the program to add 2 and 3.",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=original_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -69,6 +69,9 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use the input shape to determine the dimension of normalization
+  # when type is LAYER_NORM.
+  use_input_shape: bool = True
 
 
 @dataclass
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
     group_num (int): Number of groups to separate the channels into.
    dim (int): Dimension of the input tensor.
     eps (float): A small float value to ensure numerical stability (default:
-      1e-6).
+      1e-5).
     enable_hlfb (bool): Whether to convert this normalization into a single
       op.
   """
@@ -112,7 +112,13 @@ class GroupNorm(torch.nn.Module):
 
 class LayerNorm(torch.nn.Module):
 
-  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+      use_input_shape: bool = True,
+  ):
     """Initialize the LayerNorm layer.
 
     Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
       1e-6).
     enable_hlfb (bool): Whether to convert this normalization into a single
       op.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization (default: True).
     """
     super().__init__()
     self.enable_hlfb = enable_hlfb
+    self.use_input_shape = use_input_shape
     self.eps = eps
     self.weight = torch.nn.Parameter(torch.ones(dim))
     self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@ class LayerNorm(torch.nn.Module):
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.eps,
+          x, self.weight, self.bias, self.eps, self.use_input_shape
       )
+
+    if self.use_input_shape:
+      normalized_shape = x.shape
+      weight = self.weight.broadcast_to(x.shape)
+      bias = self.bias.broadcast_to(x.shape)
     else:
-      return F.layer_norm(
-          x,
-          x.shape,
-          self.weight.broadcast_to(x.shape),
-          self.bias.broadcast_to(x.shape),
-          self.eps,
-      )
+      normalized_shape = self.weight.shape
+      weight = self.weight
+      bias = self.bias
+    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)
 
 
 def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
     w: torch.Tensor,
    b: torch.Tensor,
     eps: float,
+    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.
 
@@ -201,18 +210,20 @@
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization.
 
   Returns:
     The output tensor of Layer Normalization.
   """
   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
   x, w, b = builder.mark_inputs(x, w, b)
-  y = F.layer_norm(
-      x,
-      x.shape,
-      weight=w.broadcast_to(x.shape),
-      bias=b.broadcast_to(x.shape),
-      eps=eps,
-  )
+  if use_input_shape:
+    normalized_shape = x.shape
+    w = w.broadcast_to(x.shape)
+    b = b.broadcast_to(x.shape)
+  else:
+    normalized_shape = w.shape
+  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
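The flag matters because F.layer_norm(x, x.shape, ...) computes one mean and variance over the entire tensor, while F.layer_norm(x, weight.shape, ...) normalizes each position over the last dimension only, which is what torch.nn.LayerNorm and the Phi-2 checkpoint expect. A minimal standalone sketch of the two modes:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 8)
    w, b = torch.ones(8), torch.zeros(8)

    # use_input_shape=True: a single mean/variance over the whole tensor.
    y_full = F.layer_norm(
        x, x.shape, w.broadcast_to(x.shape), b.broadcast_to(x.shape), 1e-5
    )
    # use_input_shape=False: per-position stats over the last dim only,
    # matching torch.nn.LayerNorm(8).
    y_last = F.layer_norm(x, w.shape, w, b, 1e-5)

    print(torch.allclose(y_full, y_last))  # False in general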
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
     # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
     k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
     v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)
 
   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
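The else branch implements attention logit soft-capping (as used by, e.g., Gemma 2): scores are squashed through tanh so no logit can exceed ±softcap before the mask and softmax are applied. F.scaled_dot_product_attention has no such argument, hence the manual path. The transform on its own, as a tiny sketch:

    import torch

    def soft_cap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
      # Smooth, bounded squashing: roughly linear for |scores| << softcap,
      # saturating toward +/-softcap for large magnitudes.
      return softcap * torch.tanh(scores / softcap)

    scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    print(soft_cap(scores, softcap=50.0))
    # ≈ tensor([-48.2014, -4.9834, 0.0000, 4.9834, 48.2014])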
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
     )
     self.conv_1 = nn.Conv2d(
         config.in_channels,
-        config.out_channels,
+        config.hidden_channels,
         kernel_size=3,
         stride=1,
         padding=1,
     )
     if config.time_embedding_channels is not None:
       self.time_emb_proj = nn.Linear(
-          config.time_embedding_channels, config.out_channels
+          config.time_embedding_channels, config.hidden_channels
       )
     else:
       self.time_emb_proj = None
     self.norm_2 = layers_builder.build_norm(
-        config.out_channels, config.normalization_config
+        config.hidden_channels, config.normalization_config
     )
     self.conv_2 = nn.Conv2d(
-        config.out_channels,
+        config.hidden_channels,
         config.out_channels,
         kernel_size=3,
         stride=1,
@@ -391,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -492,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -602,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -706,6 +709,7 @@ class MidBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
+                 hidden_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -724,6 +728,7 @@ class MidBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
+                 hidden_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -48,6 +48,7 @@ class DownSamplingConfig:
 @dataclasses.dataclass
 class ResidualBlock2DConfig:
   in_channels: int
+  hidden_channels: int
   out_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
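With the new field, ResidualBlock2D's two convolutions map in_channels → hidden_channels → out_channels rather than forcing the intermediate width to equal out_channels; every call site updated in this diff passes the previous value (out_channels, or in_channels for MidBlock2D), so numerics are unchanged. A shape-only sketch with illustrative widths:

    import torch
    from torch import nn

    in_ch, hidden_ch, out_ch = 64, 128, 128  # illustrative widths
    conv_1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, stride=1, padding=1)
    conv_2 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, stride=1, padding=1)

    x = torch.randn(1, in_ch, 32, 32)
    print(conv_2(conv_1(x)).shape)  # torch.Size([1, 128, 32, 32])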
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
   ):
     residual_block_config = unet_config.ResidualBlock2DConfig(
         in_channels=config.in_channels,
+        hidden_channels=config.in_channels,
        out_channels=config.in_channels,
        time_embedding_channels=config.time_embedding_channels,
        normalization_config=config.normalization_config,
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=resnet_in_channels + res_skip_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
ai_edge_torch/generative/utilities/verifier.py ADDED
@@ -0,0 +1,200 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utility functions to verify the reauthored models."""
+
+import datetime
+from typing import List
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import numpy as np
+import torch
+
+
+def log_msg(*args):
+  print("[%s]" % datetime.datetime.now(), *args)
+
+
+def forward(
+    model: torch.nn.Module,
+    tokens: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+) -> tuple[torch.Tensor, kv_utils.KVCache]:
+  """Forwards the model reauthored with ai_edge_torch Generative API.
+
+  Args:
+    model (torch.nn.Module): The model to forward. It should be a model built
+      with ai_edge_torch Generative API.
+    tokens (torch.Tensor): The input tokens to forward.
+    kv_cache (KVCache): The KV cache to forward.
+
+  Returns:
+    The output logits and the updated KV cache.
+  """
+  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+  output = model.forward(tokens, input_pos, kv_cache)
+  return output["logits"], output["kv_cache"]
+
+
+def generate(
+    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
+) -> torch.Tensor:
+  """Generates the response to the prompts.
+
+  It appends tokens output by the model to the prompts and feeds them back to
+  the model up to response_len.
+
+  Args:
+    model (torch.nn.Module): The model to generate. It should be a model built
+      with ai_edge_torch Generative API.
+    prompts (torch.Tensor): The prompts to generate.
+    response_len (int): The number of tokens to generate.
+
+  Returns:
+    The generated tokens.
+  """
+  input_ids = prompts[0].int().tolist()
+  kv_cache = kv_utils.KVCache.from_model_config(model.config)
+  for _ in range(response_len - len(input_ids)):
+    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
+    generated_token = logits[0][-1].argmax().item()
+    input_ids.append(generated_token)
+  return torch.tensor([input_ids])
+
+
+def verify_with_input_ids(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    kv_cache_max_len: int = 1024,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+) -> bool:
+  """Verifies if the reauthored model generates the same output as the original.
+
+  It compares only one output from the original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    input_ids (torch.Tensor): The input token IDs to forward.
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+
+  Returns:
+    True if the reauthored model generates the same output as the original.
+  """
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  input_ids_len = input_ids.shape[1]
+  tokens[0, :input_ids_len] = input_ids
+
+  log_msg("Forwarding the original model...")
+  outputs_original = original_model.forward(tokens)
+  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  log_msg("logits_original: ", logits_original)
+
+  log_msg("Forwarding the reauthored model...")
+  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
+  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
+  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  log_msg("logits_reauthored:", logits_reauthored)
+
+  return torch.allclose(
+      logits_original, logits_reauthored, rtol=rtol, atol=atol
+  )
+
+
+def verify_model_with_prompts(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: str,
+) -> bool:
+  """Verifies if the reauthored model generates the same answer as the original.
+
+  It compares an answer, i.e. multiple continuous outputs generated by the
+  original and the reauthored model.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (str): The input prompts to generate answers.
+
+  Returns:
+    True if the reauthored model generates the same answer as the original.
+  """
+  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+
+  log_msg("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens)
+  response_original = tokenizer.decode(outputs_original[0])
+  log_msg("outputs_from_original_model: [[", response_original, "]]")
+
+  log_msg("Generating answer with the reauthored model...")
+  generate_len = len(outputs_original[0])
+  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  response_reauthored = tokenizer.decode(outputs_reauthored[0])
+  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+
+  return response_original == response_reauthored
+
+
+def verify_reauthored_model(
+    original_model: torch.nn.Module,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: List[str],
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored model against the original model.
+
+  It verifies the reauthored model with two methods:
+  1. It compares the output of the original and the reauthored model with an
+     arbitrary input.
+  2. It compares the answer generated by the original and the reauthored model
+     with a prompt.
+
+  It prints out "PASS" or "FAILED" to the console.
+
+  Args:
+    original_model (torch.nn.Module): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (List[str]): List of the input prompts to generate answers.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+  """
+  log_msg("Verifying the reauthored model with an arbitrary input...")
+  if verify_with_input_ids(
+      original_model, reauthored_model, rtol=rtol, atol=atol
+  ):
+    log_msg("PASS")
+  else:
+    log_msg("FAILED")
+
+  for p in prompts:
+    log_msg("Verifying the reauthored model with prompts:", p)
+    if verify_model_with_prompts(
+        original_model, reauthored_model, tokenizer, p
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240918"
+__version__ = "0.3.0.dev20240919"
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240918
+Version: 0.3.0.dev20240919
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=jWg5qA8V0XqgFoqjk0SCsNWPRBeTmfrir9u0bucHYOU,706
+ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -45,13 +45,16 @@ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
-ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
+ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
-ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
+ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
-ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
+ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -76,23 +79,24 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71Wvv
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
+ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
 ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
-ai_edge_torch/generative/layers/model_config.py,sha256=zV3pA7giuKPrQdH81dpZz8D6LfGD-1YHuXuhIlypKc0,6784
-ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
+ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
+ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
-ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
+ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -104,14 +108,15 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
 ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5z9dvspFKO74Y_qIJ_VK0MYUoPdRf82Y,4498
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
-ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
+ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -158,8 +163,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/METADATA,sha256=dMaIr8Iny84IfNGQGSrtlTGkYlH_mAMmgvGWm5-pkxM,1859
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,