ai-edge-torch-nightly 0.3.0.dev20240920__py3-none-any.whl → 0.3.0.dev20240923__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} +3 -3
  2. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +3 -36
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -26
  4. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +55 -0
  5. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +55 -0
  6. ai_edge_torch/generative/examples/gemma/verify_util.py +142 -0
  7. ai_edge_torch/generative/examples/openelm/verify.py +1 -1
  8. ai_edge_torch/generative/examples/phi/verify.py +1 -1
  9. ai_edge_torch/generative/examples/smollm/verify.py +1 -1
  10. ai_edge_torch/generative/examples/tiny_llama/verify.py +1 -1
  11. ai_edge_torch/generative/layers/feed_forward.py +0 -1
  12. ai_edge_torch/generative/quantize/example.py +3 -3
  13. ai_edge_torch/generative/test/test_model_conversion_large.py +4 -4
  14. ai_edge_torch/generative/utilities/verifier.py +21 -19
  15. ai_edge_torch/version.py +1 -1
  16. {ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/RECORD +20 -17
  18. {ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/LICENSE +0 -0
  19. {ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/WHEEL +0 -0
  20. {ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} RENAMED
@@ -13,14 +13,14 @@
  # limitations under the License.
  # ==============================================================================

- """Example of converting a Gemma model to multi-signature tflite model."""
+ """Example of converting a Gemma1 model to multi-signature tflite model."""

  import os
  import pathlib

  from absl import app
  from absl import flags
- from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.utilities import converter

  _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -51,7 +51,7 @@ _QUANTIZE = flags.DEFINE_bool(


  def main(_):
-   pytorch_model = gemma.build_2b_model(
+   pytorch_model = gemma1.build_2b_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
    quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} RENAMED
@@ -13,10 +13,7 @@
  # limitations under the License.
  # ==============================================================================

- """Example of building a Gemma model."""
-
- import os
- import pathlib
+ """Example of building a Gemma1 model."""

  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
@@ -24,7 +21,6 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
  import torch
  from torch import nn

@@ -32,13 +28,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
      ff_down_proj="model.layers.{}.mlp.down_proj",
      ff_gate_proj="model.layers.{}.mlp.gate_proj",
-     attn_query_proj="model.layers.{}.self_attn.q_proj",
-     attn_key_proj="model.layers.{}.self_attn.k_proj",
-     attn_value_proj="model.layers.{}.self_attn.v_proj",
+     attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
      attn_output_proj="model.layers.{}.self_attn.o_proj",
      pre_attn_norm="model.layers.{}.input_layernorm",
      post_attn_norm="model.layers.{}.post_attention_layernorm",
-     embedding="model.embed_tokens",
+     embedding="embedder",
      final_norm="model.norm",
      lm_head=None,
  )
@@ -192,30 +186,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    loader.load(model, strict=False)
    model.eval()
    return model
-
-
- def define_and_run_2b(checkpoint_path: str) -> None:
-   """Instantiates and runs a Gemma 2B model."""
-
-   current_dir = pathlib.Path(__file__).parent.resolve()
-   gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
-
-   kv_cache_max_len = 1024
-   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(model.config)
-   output = model.forward(tokens, input_pos, kv)
-   print("comparing with goldens..")
-   assert torch.allclose(
-       gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-   )
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(
-       pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
-   )
-   define_and_run_2b(input_checkpoint_path)
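Note: the tensor-name changes above mean the reauthored Gemma1 loader now reads a single fused QKV weight ("self_attn.qkv_proj") and an "embedder" embedding table instead of separate q/k/v projections and "model.embed_tokens". A minimal sketch of the fusion this mapping assumes, with illustrative multi-query shapes (8 query heads and 1 KV head of dim 256 over a 2048-wide hidden state; the shapes are not taken from this diff):

    import torch

    # Hypothetical shapes, chosen only to illustrate the layout.
    q_proj = torch.randn(8 * 256, 2048)  # all query heads
    k_proj = torch.randn(1 * 256, 2048)  # single shared key head
    v_proj = torch.randn(1 * 256, 2048)  # single shared value head

    # A fused qkv_proj is the three projections stacked on the output axis;
    # a loader can slice it back apart at load time.
    qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)
    assert qkv_proj.shape == (10 * 256, 2048)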
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -267,29 +267,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
    loader.load(model, strict=False)
    model.eval()
    return model
-
-
- def define_and_run_2b(checkpoint_path: str) -> None:
-   """Instantiates and runs a Gemma2 2B model."""
-
-   current_dir = pathlib.Path(__file__).parent.resolve()
-   gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
-   print("Running GEMMA 2")
-   kv_cache_max_len = 1024
-   model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-   toks = torch.from_numpy(
-       np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
-   )
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   tokens[0, :9] = toks
-   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(model.config)
-   out = model.forward(tokens, input_pos, kv)
-   out_final = out["logits"][0, 8, :]
-   assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
-
-
- if __name__ == "__main__":
-   torch.set_printoptions(sci_mode=True)
-   path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
-   define_and_run_2b(path)
ai_edge_torch/generative/examples/gemma/verify_gemma1.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Gemma1 model."""
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.gemma import gemma1
+ from ai_edge_torch.generative.examples.gemma import verify_util
+ from ai_edge_torch.generative.utilities import verifier
+ import kagglehub
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+
+ def main(_):
+   checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
+
+   verifier.log_msg("Building the reauthored model from", checkpoint)
+   reauthored_model = gemma1.build_2b_model(checkpoint)
+
+   verify_util.verify_reauthored_gemma_model(
+       checkpoint=checkpoint,
+       variant="2b",
+       reauthored_model=reauthored_model,
+       weight_filename="gemma-2b-it.ckpt",
+       generate_prompts=_PROMPTS.value,
+       forward_input_ids=[[1, 2, 3, 4]],
+       max_new_tokens=_MAX_NEW_TOKENS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
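The new verify scripts are absl apps driven by the two flags defined above. A hedged sketch of invoking one programmatically (the sys.argv override is illustrative; running the file directly behaves the same, and the kagglehub download assumes Kaggle credentials are configured):

    import sys

    from absl import app
    from ai_edge_torch.generative.examples.gemma import verify_gemma1

    # Override the flags defined in verify_gemma1.py; the values are examples.
    sys.argv = [
        "verify_gemma1",
        "--prompts=What is the meaning of life?",
        "--max_new_tokens=30",
    ]
    app.run(verify_gemma1.main)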
ai_edge_torch/generative/examples/gemma/verify_gemma2.py ADDED
@@ -0,0 +1,55 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Gemma2 model."""
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.examples.gemma import verify_util
+ from ai_edge_torch.generative.utilities import verifier
+ import kagglehub
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+
+ def main(_):
+   checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
+
+   verifier.log_msg("Building the reauthored model from", checkpoint)
+   reauthored_model = gemma2.build_2b_model(checkpoint)
+
+   verify_util.verify_reauthored_gemma_model(
+       checkpoint=checkpoint,
+       variant="2b-v2",
+       reauthored_model=reauthored_model,
+       generate_prompts=_PROMPTS.value,
+       forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+       max_new_tokens=_MAX_NEW_TOKENS.value,
+       atol=1e-04,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
ai_edge_torch/generative/examples/gemma/verify_util.py ADDED
@@ -0,0 +1,142 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utility functions to verify the reauthored Gemma model."""
+
+ import dataclasses
+ import os
+ from typing import List, Tuple
+
+ import ai_edge_torch.generative.layers.attention_utils as attn_utils
+ from ai_edge_torch.generative.utilities import verifier
+ from gemma import config as gemma_config
+ from gemma import model as gemma_model
+ import torch
+
+
+ @dataclasses.dataclass
+ class _Output:
+   logits: torch.Tensor
+
+
+ class GemmaWrapper(verifier.ModelWrapper):
+   """Gemma model wrapper for verification.
+
+   Verifier calls model.forward() with the maximum sequence length (1024),
+   expecting the output to have a 'logits' field, while Gemma takes the input
+   tokens with the actual length and returns logits in a tuple.
+
+   Verifier runs the tokenizer before model.generate() while Gemma runs the
+   tokenizer inside model.generate().
+   """
+
+   def __init__(self, model: torch.nn.Module, max_new_tokens: int):
+     super().__init__(model)
+     self.max_new_tokens = max_new_tokens
+
+   def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
+     for i in range(tokens.shape[1]):
+       if tokens[0, i] == 0:
+         return i
+     return tokens.shape[1]
+
+   def _get_kv_caches(
+       self, max_seq_len: int
+   ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+     config = self.model.config
+     cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+     cache = torch.zeros(cache_size)
+     return [
+         (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+     ]
+
+   def forward(self, tokens: torch.Tensor) -> _Output:
+     """Forwards the model after reducing input tokens to the actual length."""
+     actual_input_len = self._get_actual_input_len(tokens)
+     input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+     mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+     _, logits = self.model.forward(
+         input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+         input_positions=input_pos,
+         kv_write_indices=None,
+         kv_caches=self._get_kv_caches(tokens.shape[1]),
+         mask=mask_cache.index_select(2, input_pos),
+         output_positions=input_pos,
+         temperatures=None,
+         top_ps=torch.tensor([1.0], dtype=torch.float),
+         top_ks=torch.tensor([1], dtype=torch.long),
+     )
+     return _Output(logits.float())
+
+   def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+     """Generates the response after decoding the tokens into a string."""
+     prompts = self.model.tokenizer.decode(tokens[0].tolist())
+     response = self.model.generate(
+         prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+     )
+     return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+ class TokenizerWrapper(torch.nn.Module):
+   """Tokenizer wrapper for verification.
+
+   Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
+   tokenizer expects tokens in a list.
+   """
+
+   def __init__(self, tokenizer: torch.nn.Module):
+     super().__init__()
+     self.tokenizer = tokenizer
+
+   def encode(self, text: str, **_) -> torch.Tensor:
+     """Adds one more dimension to the output of the tokenizer."""
+     return torch.tensor([self.tokenizer.encode(text)])
+
+   def decode(self, tokens: torch.Tensor) -> str:
+     """Decodes the token sequence after converting to a list."""
+     return self.tokenizer.decode(tokens.tolist())
+
+
+ def verify_reauthored_gemma_model(
+     checkpoint: str,
+     variant: str,
+     reauthored_model: torch.nn.Module,
+     generate_prompts: List[str],
+     forward_input_ids: List[List[int]],
+     weight_filename: str = "model.ckpt",
+     tokenizer_filename: str = "tokenizer.model",
+     max_new_tokens: int = 20,
+     rtol: float = 1e-05,
+     atol: float = 1e-05,
+ ):
+   """Verifies the reauthored Gemma model against the original model."""
+   config = gemma_config.get_model_config(variant)
+   config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+   # Use float32 to be compatible with the reauthored model.
+   config.dtype = torch.float32
+
+   verifier.log_msg("Loading the original model from", checkpoint)
+   original_model = gemma_model.GemmaForCausalLM(config).eval()
+   original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+   verifier.verify_reauthored_model(
+       original_model=GemmaWrapper(original_model, max_new_tokens),
+       reauthored_model=reauthored_model,
+       tokenizer=TokenizerWrapper(original_model.tokenizer),
+       generate_prompts=generate_prompts,
+       forward_input_ids=forward_input_ids,
+       rtol=rtol,
+       atol=atol,
+   )
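TokenizerWrapper above only adapts calling conventions: list-based encode/decode on the Gemma side, torch.Tensor on the verifier side. A self-contained sketch of that round trip with a stand-in tokenizer (the character-level FakeTokenizer is hypothetical, purely to show the shapes):

    import torch

    from ai_edge_torch.generative.examples.gemma.verify_util import TokenizerWrapper


    class FakeTokenizer:
      """Stand-in with the list-based API the real Gemma tokenizer exposes."""

      def encode(self, text: str) -> list[int]:
        return [ord(c) for c in text]

      def decode(self, ids: list[int]) -> str:
        return "".join(chr(i) for i in ids)


    wrapped = TokenizerWrapper(FakeTokenizer())
    ids = wrapped.encode("Hi")             # tensor([[72, 105]]), shape (1, seq_len)
    assert wrapped.decode(ids[0]) == "Hi"  # back to a plain string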
ai_edge_torch/generative/examples/openelm/verify.py CHANGED
@@ -55,7 +55,7 @@ def main(_):
        original_model=wrapper_model,
        reauthored_model=reauthored_model,
        tokenizer=tokenizer,
-       prompts=_PROMPTS.value,
+       generate_prompts=_PROMPTS.value,
    )


ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -54,7 +54,7 @@ def main(_):
        original_model=wrapper_model,
        reauthored_model=reauthored_model,
        tokenizer=tokenizer,
-       prompts=_PROMPTS.value,
+       generate_prompts=_PROMPTS.value,
        atol=1e-03,
    )

ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -51,7 +51,7 @@ def main(_):
        original_model=wrapper_model,
        reauthored_model=reauthored_model,
        tokenizer=tokenizer,
-       prompts=_PROMPTS.value,
+       generate_prompts=_PROMPTS.value,
        atol=1e-04,
    )

ai_edge_torch/generative/examples/tiny_llama/verify.py CHANGED
@@ -53,7 +53,7 @@ def main(_):
        original_model=wrapper_model,
        reauthored_model=reauthored_model,
        tokenizer=tokenizer,
-       prompts=_PROMPTS.value,
+       generate_prompts=_PROMPTS.value,
        atol=1e-04,
    )

ai_edge_torch/generative/layers/feed_forward.py CHANGED
@@ -18,7 +18,6 @@ from typing import Callable, Optional

  import torch
  from torch import nn
- import torch.nn.functional as F


  class SequentialFeedForward(nn.Module):
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -14,7 +14,7 @@
  # ==============================================================================

  import ai_edge_torch
- from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.quantize import quant_recipes
  import numpy as np
  import torch
@@ -22,8 +22,8 @@ import torch

  def main():
    # Build a PyTorch model as usual
-   config = gemma.get_fake_model_config()
-   model = gemma.Gemma(config)
+   config = gemma1.get_fake_model_config()
+   model = gemma1.Gemma(config)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
    tokens[0, :4] = idx
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -17,7 +17,7 @@

  import ai_edge_torch
  from ai_edge_torch import config as ai_edge_config
- from ai_edge_torch.generative.examples.gemma import gemma
+ from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
  from ai_edge_torch.generative.examples.openelm import openelm
  from ai_edge_torch.generative.examples.phi import phi2
@@ -82,9 +82,9 @@ class TestModelConversion(googletest.TestCase):
        ai_edge_config.Config.use_torch_xla,
        reason="tests with custom ops are not supported on oss",
    )
-   def test_gemma(self):
-     config = gemma.get_fake_model_config()
-     pytorch_model = gemma.Gemma(config).eval()
+   def test_gemma1(self):
+     config = gemma1.get_fake_model_config()
+     pytorch_model = gemma1.Gemma(config).eval()
      self._test_model(
          config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
      )
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -19,7 +19,6 @@ import datetime
  from typing import List, Optional, Union

  from ai_edge_torch.generative.layers import kv_cache as kv_utils
- import numpy as np
  import torch
  import transformers

@@ -126,7 +125,7 @@
  def verify_with_input_ids(
      original_model: ModelWrapper,
      reauthored_model: torch.nn.Module,
-     input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+     input_ids: List[int],
      kv_cache_max_len: int = 1024,
      rtol: float = 1e-05,
      atol: float = 1e-05,
@@ -139,7 +138,7 @@
      original_model (ModelWrapper): The original model.
      reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
        Generative API.
-     input_ids (torch.Tensor): The input token IDs to forward.
+     input_ids (List[int]): The input token IDs to forward with.
      kv_cache_max_len (int): The maximum sequence length of the KV cache.
      rtol (float): The relative tolerance for the comparison.
      atol (float): The absolute tolerance for the comparison.
@@ -148,18 +147,17 @@
      True if the model reauthored generates the same output of the original.
    """
    tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   input_ids_len = input_ids.shape[1]
-   tokens[0, :input_ids_len] = input_ids
+   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()

    log_msg("Forwarding the original model...")
    outputs_original = original_model.forward(tokens)
-   logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+   logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
    log_msg("logits_original: ", logits_original)

    log_msg("Forwarding the reauthored model...")
    kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
    outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-   logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+   logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
    log_msg("logits_reauthored:", logits_reauthored)

    return torch.allclose(
@@ -208,7 +206,8 @@
      original_model: ModelWrapper,
      reauthored_model: torch.nn.Module,
      tokenizer: torch.nn.Module,
-     prompts: List[str],
+     generate_prompts: List[str],
+     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
      rtol: float = 1e-05,
      atol: float = 1e-05,
  ):
@@ -227,22 +226,25 @@
      reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
        Generative API.
      tokenizer (torch.nn.Module): The tokenizer.
-     prompts (List[str]): List of the input prompts to generate answers.
+     generate_prompts (List[str]): List of the input prompts to generate answers.
+     forward_input_ids (List[List[int]]): List of the input token IDs to
+       forward with.
      rtol (float): The relative tolerance for the comparison.
      atol (float): The absolute tolerance for the comparison.
    """
-   log_msg("Verifying the reauthored model with an arbitrary input...")
-   if verify_with_input_ids(
-       original_model, reauthored_model, rtol=rtol, atol=atol
-   ):
-     log_msg("PASS")
-   else:
-     log_msg("FAILED")
+   for input_ids in forward_input_ids:
+     log_msg("Verifying the reauthored model with input IDs:", input_ids)
+     if verify_with_input_ids(
+         original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+     ):
+       log_msg("PASS")
+     else:
+       log_msg("FAILED")

-   for p in prompts:
-     log_msg("Verifying the reauthored model with prompts:", p)
+   for prompts in generate_prompts:
+     log_msg("Verifying the reauthored model with prompts:", prompts)
      if verify_model_with_prompts(
-         original_model, reauthored_model, tokenizer, p
+         original_model, reauthored_model, tokenizer, prompts
      ):
        log_msg("PASS")
      else:
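The net effect on verifier callers: "prompts" is renamed to "generate_prompts", and the forward-pass check now takes explicit "forward_input_ids" (a list of token-ID lists) instead of a hard-coded tensor. A hedged sketch of a caller after this change (wrapper_model, reauthored_model, and tokenizer are assumed to be built beforehand, as in the per-model verify.py scripts in this diff):

    from ai_edge_torch.generative.utilities import verifier

    verifier.verify_reauthored_model(
        original_model=wrapper_model,        # assumed: a verifier.ModelWrapper
        reauthored_model=reauthored_model,   # assumed: built via a build_*_model()
        tokenizer=tokenizer,                 # assumed: a tokenizer wrapper
        generate_prompts=["What is the meaning of life?"],
        forward_input_ids=[[1, 2, 3, 4]],
        atol=1e-04,
    )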
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20240920"
+ __version__ = "0.3.0.dev20240923"
{ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240920
+ Version: 0.3.0.dev20240923
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240920.dist-info → ai_edge_torch_nightly-0.3.0.dev20240923.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=-oH0R07HZpydzqltOWclHB1dbcc4VycTlZcnDYtS89g,706
+ ai_edge_torch/version.py,sha256=oxtOOEY9LJkV5vRrgr1EoSjAjuetYVNq7WQqMuauRkc,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,22 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=t8Qg10obnEzeoMeyHnZhyNBN7G85SGy-au8Y8nehq8E,2181
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=KsM6VlzluTqbodG24IFr3biPxBrLay0z0gmnG0bcU2U,9277
+ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=B14IR4mIw6qBVUbiIRdfdUzHMCIJCJ0RFPsYOxA46qc,1776
+ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax1m4WTAnC3gMxBpcvGA7-xTO1Iuw,1802
+ ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
- ai_edge_torch/generative/examples/openelm/verify.py,sha256=BvK4c8jodQBy2l3NnvCjlBB0qaA7EYwPNKklvFR4k_o,2103
+ ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
- ai_edge_torch/generative/examples/phi/verify.py,sha256=5bKONolW8JIsQAzMHIvh_OSytoJVVJqDZEcxjhciFnI,2136
+ ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
- ai_edge_torch/generative/examples/smollm/verify.py,sha256=wsoy3CaHZhrdJjkJJYir7xxxwgCvLprMnh8QxT0hEkc,2013
+ ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -80,14 +83,14 @@ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
- ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=27oBf706_AKX7amfp2THF9J0G3AUEEecGaXv025idKA,2086
+ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LUChL5tA7FHL_DlTg5QKvGInmH9AwVVw9a-omcndiz8,2095
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
  ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
- ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
+ ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
  ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
  ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
@@ -98,7 +101,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
+ ai_edge_torch/generative/quantize/example.py,sha256=tlACaRsz6lqOxakzpXVFJZYfFKOiFqetcYVJqWVRdPE,1542
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
  ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -108,7 +111,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=PtePuBqVMLjxq2cDIIXXqaz7zsn3R19oilFyIVJRFi8,4490
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -116,7 +119,7 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
  ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
- ai_edge_torch/generative/utilities/verifier.py,sha256=7DoYtkilz4wjWnXfdydIGNgTG1udZIydFxdbpIcKbMQ,8625
+ ai_edge_torch/generative/utilities/verifier.py,sha256=vU9KgmFS7I9jNS_3H2SWROx-rbNqtMKgQC2MRhdqQ4g,8803
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -163,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20240920.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20240920.dist-info/METADATA,sha256=m60oD-H8W2EMVolDGw02tMYcKDrotaTaLtsZwzr_Kyk,1859
- ai_edge_torch_nightly-0.3.0.dev20240920.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- ai_edge_torch_nightly-0.3.0.dev20240920.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20240920.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/METADATA,sha256=BgwLxDJ3AOPVn0fkngAQpf3YdmShufhMt3bANFevtiQ,1859
+ ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20240923.dist-info/RECORD,,