ai-edge-torch-nightly 0.3.0.dev20240919__py3-none-any.whl → 0.3.0.dev20240921__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. ai_edge_torch/_convert/test/test_convert.py +7 -3
  2. ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} +9 -7
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
  4. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +3 -36
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -26
  6. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +55 -0
  7. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +55 -0
  8. ai_edge_torch/generative/examples/gemma/verify_util.py +142 -0
  9. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
  10. ai_edge_torch/generative/examples/openelm/verify.py +6 -4
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
  12. ai_edge_torch/generative/examples/phi/verify.py +14 -4
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
  14. ai_edge_torch/generative/examples/smollm/verify.py +5 -4
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
  16. ai_edge_torch/generative/examples/tiny_llama/verify.py +6 -5
  17. ai_edge_torch/generative/layers/feed_forward.py +0 -1
  18. ai_edge_torch/generative/quantize/example.py +3 -3
  19. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  20. ai_edge_torch/generative/test/test_model_conversion_large.py +5 -5
  21. ai_edge_torch/generative/utilities/verifier.py +77 -26
  22. ai_edge_torch/model.py +7 -4
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD +28 -25
  26. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -23,12 +23,12 @@ from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision
 
 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
 
 
 @dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])
 
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
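The change above swaps tf.lite.Interpreter for the standalone LiteRT interpreter. A minimal sketch of the new usage, assuming "model.tflite" is any converted model on disk (the path is a hypothetical stand-in):

import numpy as np
from ai_edge_litert import interpreter as tfl_interpreter

# Load a serialized TFLite flatbuffer and run its default signature.
with open("model.tflite", "rb") as f:
  tflite_bytes = f.read()

interpreter = tfl_interpreter.Interpreter(model_content=tflite_bytes)
runner = interpreter.get_signature_runner("serving_default")
# Inspect the signature's inputs/outputs before invoking with numpy arrays.
print(runner.get_input_details())
print(runner.get_output_details())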
ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} RENAMED
@@ -13,14 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Example of converting a Gemma model to multi-signature tflite model."""
+"""Example of converting a Gemma1 model to multi-signature tflite model."""
 
 import os
 import pathlib
 
 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -51,12 +51,14 @@ _QUANTIZE = flags.DEFINE_bool(
 
 
 def main(_):
-  pytorch_model = gemma.build_2b_model(
+  pytorch_model = gemma1.build_2b_model(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
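With this change, --tflite_path names an output directory and the filename encodes the quantization mode and sequence settings. A quick sketch of the path produced under the new defaults:

import os

# Mirrors the naming logic added above, using the new default flag values.
quantize, prefill_seq_len, kv_cache_max_len = True, 1024, 1280
quant_suffix = 'q8' if quantize else 'f32'
output_filename = f'gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
print(os.path.join('/tmp/', output_filename))  # -> /tmp/gemma_q8_seq1024_ekv1280.tflite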
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} RENAMED
@@ -13,10 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Example of building a Gemma model."""
-
-import os
-import pathlib
+"""Example of building a Gemma1 model."""
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -24,7 +21,6 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -32,13 +28,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
     post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
+    embedding="embedder",
     final_norm="model.norm",
     lm_head=None,
 )
@@ -192,30 +186,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-  """Instantiates and runs a Gemma 2B model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
-
-  kv_cache_max_len = 1024
-  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
-  )
-  define_and_run_2b(input_checkpoint_path)
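The TENSOR_NAMES change above targets a checkpoint layout where the query, key, and value projections are stored as one fused qkv_proj tensor and the embedding table is named "embedder". Illustrative only, with hypothetical Gemma-2B-like shapes (the real mapping is handled by the ModelLoader), here is how separate q/k/v weights relate to such a fused tensor:

import torch

# Hypothetical shapes: 8 query heads, 1 KV head (multi-query attention).
num_heads, num_kv_heads, head_dim, embed_dim = 8, 1, 256, 2048
q = torch.randn(num_heads * head_dim, embed_dim)
k = torch.randn(num_kv_heads * head_dim, embed_dim)
v = torch.randn(num_kv_heads * head_dim, embed_dim)
qkv_proj = torch.cat([q, k, v], dim=0)  # fused along the output dimension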
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -267,29 +267,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-  """Instantiates and runs a Gemma2 2B model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
-  print("Running GEMMA 2")
-  kv_cache_max_len = 1024
-  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  toks = torch.from_numpy(
-      np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
-  )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  out = model.forward(tokens, input_pos, kv)
-  out_final = out["logits"][0, 8, :]
-  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
-
-
-if __name__ == "__main__":
-  torch.set_printoptions(sci_mode=True)
-  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
-  define_and_run_2b(path)
ai_edge_torch/generative/examples/gemma/verify_gemma1.py ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma1 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma1
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = gemma1.build_2b_model(checkpoint)
+
+  verify_util.verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename="gemma-2b-it.ckpt",
+      generate_prompts=_PROMPTS.value,
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/gemma/verify_gemma2.py ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = gemma2.build_2b_model(checkpoint)
+
+  verify_util.verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=_PROMPTS.value,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/gemma/verify_util.py ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions to verify the reauthored Gemma model."""
+
+import dataclasses
+import os
+from typing import List, Tuple
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import verifier
+from gemma import config as gemma_config
+from gemma import model as gemma_model
+import torch
+
+
+@dataclasses.dataclass
+class _Output:
+  logits: torch.Tensor
+
+
+class GemmaWrapper(verifier.ModelWrapper):
+  """Gemma model wrapper for verification.
+
+  Verifier calls model.forward() with maximum sequence length (1024) expecting
+  the output has 'logits' field while Gemma gets the input tokens with the
+  actual length and returns logits in a tuple.
+
+  Verifier runs tokenizer before model.generate() while Gemma runs the
+  tokenizer inside model.generate().
+  """
+
+  def __init__(self, model: torch.nn.Module, max_new_tokens: int):
+    super().__init__(model)
+    self.max_new_tokens = max_new_tokens
+
+  def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
+    for i in range(tokens.shape[1]):
+      if tokens[0, i] == 0:
+        return i
+    return tokens.shape[1]
+
+  def _get_kv_caches(
+      self, max_seq_len: int
+  ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+    config = self.model.config
+    cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+    cache = torch.zeros(cache_size)
+    return [
+        (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+    ]
+
+  def forward(self, tokens: torch.Tensor) -> _Output:
+    """Forwards the model after reducing input tokens to the actual length."""
+    actual_input_len = self._get_actual_input_len(tokens)
+    input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+    mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+    _, logits = self.model.forward(
+        input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+        input_positions=input_pos,
+        kv_write_indices=None,
+        kv_caches=self._get_kv_caches(tokens.shape[1]),
+        mask=mask_cache.index_select(2, input_pos),
+        output_positions=input_pos,
+        temperatures=None,
+        top_ps=torch.tensor([1.0], dtype=torch.float),
+        top_ks=torch.tensor([1], dtype=torch.long),
+    )
+    return _Output(logits.float())
+
+  def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Generates the response after decoding the tokens into a string."""
+    prompts = self.model.tokenizer.decode(tokens[0].tolist())
+    response = self.model.generate(
+        prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+    )
+    return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+class TokenizerWrapper(torch.nn.Module):
+  """Tokenizer wrapper for verification.
+
+  Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
+  tokenizer expects tokens in a list.
+  """
+
+  def __init__(self, tokenizer: torch.nn.Module):
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def encode(self, text: str, **_) -> torch.Tensor:
+    """Adds one more dimension to the output of the tokenizer."""
+    return torch.tensor([self.tokenizer.encode(text)])
+
+  def decode(self, tokens: torch.Tensor) -> str:
+    """Decodes the token sequence after converting to a list."""
+    return self.tokenizer.decode(tokens.tolist())
+
+
+def verify_reauthored_gemma_model(
+    checkpoint: str,
+    variant: str,
+    reauthored_model: torch.nn.Module,
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]],
+    weight_filename: str = "model.ckpt",
+    tokenizer_filename: str = "tokenizer.model",
+    max_new_tokens: int = 20,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored Gemma model against the original model."""
+  config = gemma_config.get_model_config(variant)
+  config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+  # Use float32 to be compatible with the reauthored model.
+  config.dtype = torch.float32
+
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = gemma_model.GemmaForCausalLM(config).eval()
+  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+  verifier.verify_reauthored_model(
+      original_model=GemmaWrapper(original_model, max_new_tokens),
+      reauthored_model=reauthored_model,
+      tokenizer=TokenizerWrapper(original_model.tokenizer),
+      generate_prompts=generate_prompts,
+      forward_input_ids=forward_input_ids,
+      rtol=rtol,
+      atol=atol,
+  )
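A small sketch of the zero-padding convention that _get_actual_input_len relies on: the verifier pads the token IDs with zeros up to the KV-cache length, so the first zero marks the actual prompt length (this assumes token ID 0 never occurs inside the prompt itself):

import torch

tokens = torch.full((1, 8), 0, dtype=torch.int)
tokens[0, :4] = torch.tensor([1, 2, 3, 4], dtype=torch.int)
# Scan for the first padding zero; fall back to the full length.
actual_len = next(
    (i for i in range(tokens.shape[1]) if tokens[0, i] == 0), tokens.shape[1]
)
assert actual_len == 4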
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/openelm/verify.py CHANGED
@@ -33,8 +33,10 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "apple/OpenELM-3B"
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
   )
 
   # Locate the cached dir.
@@ -50,10 +52,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
   )
 
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -24,15 +24,25 @@ import transformers
 
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
-    "What is the meaning of life?",
+    "Instruct: Write an email about the weather Output:",
     "The input prompts to generate answers.",
 )
 
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+      hf_generation_config=generation_config,
+  )
 
   verifier.log_msg("Building the reauthored model from", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
@@ -41,10 +51,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
       atol=1e-03,
   )
 
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -33,8 +33,9 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+  )
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -47,10 +48,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
      atol=1e-04,
   )
 
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/tiny_llama/verify.py CHANGED
@@ -33,10 +33,11 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
   )
-
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -49,10 +50,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
      reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
       atol=1e-04,
   )
 
ai_edge_torch/generative/layers/feed_forward.py CHANGED
@@ -18,7 +18,6 @@ from typing import Callable, Optional
 
 import torch
 from torch import nn
-import torch.nn.functional as F
 
 
 class SequentialFeedForward(nn.Module):
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -14,7 +14,7 @@
 # ==============================================================================
 
 import ai_edge_torch
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.quantize import quant_recipes
 import numpy as np
 import torch
@@ -22,8 +22,8 @@ import torch
 
 def main():
   # Build a PyTorch model as usual
-  config = gemma.get_fake_model_config()
-  model = gemma.Gemma(config)
+  config = gemma1.get_fake_model_config()
+  model = gemma1.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
 import torch
 
 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter
 
 
 class TestModelConversion(googletest.TestCase):
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -17,7 +17,7 @@
 
 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
@@ -28,7 +28,7 @@ import numpy as np
 import torch
 
 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter
 
 
 class TestModelConversion(googletest.TestCase):
@@ -82,9 +82,9 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    pytorch_model = gemma.Gemma(config).eval()
+  def test_gemma1(self):
+    config = gemma1.get_fake_model_config()
+    pytorch_model = gemma1.Gemma(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,17 +16,65 @@
 """Common utility functions to verify the reauthored models."""
 
 import datetime
-from typing import List
+from typing import List, Optional, Union
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import numpy as np
 import torch
+import transformers
 
 
 def log_msg(*args):
   print("[%s]" % datetime.datetime.now(), *args)
 
 
+class ModelWrapper(torch.nn.Module):
+  """A wrapper for the model to be verified.
+
+  It could be a model built from HuggingFace transformers or a regular
+  PyTorch model.
+  """
+
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      model_format: str = "huggingface",
+      hf_generation_config: Optional[transformers.GenerationConfig] = None,
+  ):
+    """Initializes the wrapper.
+
+    Args:
+      model (torch.nn.Module): The original model. This could be a model built
+        from HuggingFace transformers, or a regular PyTorch model.
+      model_format (str): The format of the model. It should be either
+        "huggingface" or "pytorch".
+      hf_generation_config (transformers.GenerationConfig): The HuggingFace
+        generation config. This config will only be used if the underlying
+        model is built from HuggingFace transformers.
+    """
+    super().__init__()
+    self.model = model
+    self.model_format = model_format
+    self.hf_generation_config = hf_generation_config
+
+  def generate(
+      self, inputs: torch.Tensor
+  ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
+    if self.model_format == "huggingface":
+      return self.model.generate(
+          inputs=inputs, generation_config=self.hf_generation_config
+      )
+    else:
+      raise NotImplementedError(
+          "generate() is not implemented for model format: %s"
+          % self.model_format
+      )
+
+  def forward(
+      self,
+      inputs: torch.Tensor,
+  ):
+    return self.model.forward(inputs)
+
+
 def forward(
     model: torch.nn.Module,
     tokens: torch.Tensor,
@@ -75,9 +123,9 @@ def generate(
 
 
 def verify_with_input_ids(
-    original_model: torch.nn.Module,
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
-    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    input_ids: List[int],
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
@@ -87,10 +135,10 @@ def verify_with_input_ids(
   It compares only one output from the original and the reauthored model.
 
   Args:
-    original_model (torch.nn.Module): The original model.
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
-    input_ids (torch.Tensor): The input token IDs to forward.
+    input_ids (List[int]): The input token IDs to forward with.
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
    rtol (float): The relative tolerance for the comparison.
    atol (float): The absolute tolerance for the comparison.
@@ -99,18 +147,17 @@ def verify_with_input_ids(
     True if the reauthored model generates the same output as the original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  input_ids_len = input_ids.shape[1]
-  tokens[0, :input_ids_len] = input_ids
+  tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
 
   log_msg("Forwarding the original model...")
   outputs_original = original_model.forward(tokens)
-  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
   log_msg("logits_original: ", logits_original)
 
   log_msg("Forwarding the reauthored model...")
   kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
   outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
   log_msg("logits_reauthored:", logits_reauthored)
 
   return torch.allclose(
@@ -119,7 +166,7 @@ def verify_with_input_ids(
 
 
 def verify_model_with_prompts(
-    original_model: torch.nn.Module,
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
     tokenizer: torch.nn.Module,
     prompts: str,
@@ -130,7 +177,7 @@ def verify_model_with_prompts(
   original and the reauthored model.
 
   Args:
-    original_model (torch.nn.Module): The original model.
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
     tokenizer (torch.nn.Module): The tokenizer.
@@ -156,10 +203,11 @@ def verify_model_with_prompts(
 
 
 def verify_reauthored_model(
-    original_model: torch.nn.Module,
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
     tokenizer: torch.nn.Module,
-    prompts: List[str],
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
 ):
@@ -174,26 +222,29 @@ def verify_reauthored_model(
   It prints out "PASS" or "FAILED" to the console.
 
   Args:
-    original_model (torch.nn.Module): The original model.
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
     tokenizer (torch.nn.Module): The tokenizer.
-    prompts (List[str]): List of the input prompts to generate answers.
+    generate_prompts (List[str]): List of the input prompts to generate answers.
+    forward_input_ids (List[List[int]]): List of the input token IDs to
+      forward with.
    rtol (float): The relative tolerance for the comparison.
    atol (float): The absolute tolerance for the comparison.
   """
-  log_msg("Verifying the reauthored model with an arbitrary input...")
-  if verify_with_input_ids(
-      original_model, reauthored_model, rtol=rtol, atol=atol
-  ):
-    log_msg("PASS")
-  else:
-    log_msg("FAILED")
+  for input_ids in forward_input_ids:
+    log_msg("Verifying the reauthored model with input IDs:", input_ids)
+    if verify_with_input_ids(
+        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")
 
-  for p in prompts:
-    log_msg("Verifying the reauthored model with prompts:", p)
+  for prompts in generate_prompts:
+    log_msg("Verifying the reauthored model with prompts:", prompts)
    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, p
+        original_model, reauthored_model, tokenizer, prompts
     ):
      log_msg("PASS")
    else:
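Putting the new verifier API together, a hedged sketch of wrapping a HuggingFace model and verifying it, condensed from the smollm/verify.py changes above (note the real script resolves a locally cached config directory to pass to build_model; passing the hub ID here is a simplification):

import transformers
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"
# The original model is now handed to the verifier through ModelWrapper.
wrapper_model = verifier.ModelWrapper(
    model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
)
verifier.verify_reauthored_model(
    original_model=wrapper_model,
    reauthored_model=smollm.build_model(checkpoint),  # simplification, see note above
    tokenizer=transformers.AutoTokenizer.from_pretrained(checkpoint),
    generate_prompts=["What is the meaning of life?"],
    atol=1e-04,
)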
ai_edge_torch/model.py CHANGED
@@ -27,6 +27,8 @@ from typing import Callable
 import numpy.typing as npt
 import tensorflow as tf
 
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
+
 DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
 
@@ -65,7 +67,7 @@ class TfLiteModel(Model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
-    self._interpreter_builder = lambda: tf.lite.Interpreter(
+    self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
        model_content=self._tflite_model,
        experimental_default_delegate_latest_features=True,
    )
@@ -75,12 +77,13 @@ class TfLiteModel(Model):
     return self._tflite_model
 
   def set_interpreter_builder(
-      self, builder: Callable[[], tf.lite.Interpreter]
+      self, builder: Callable[[], tfl_interpreter.Interpreter]
   ) -> None:
     """Sets a custom interpreter builder.
 
     Args:
-      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+      builder: A function that returns a `tfl_interpreter.Interpreter` or its
+        subclass.
     """
     self._interpreter_builder = builder
 
@@ -166,7 +169,7 @@ class TfLiteModel(Model):
 
     # Check if this is indeed a tflite model:
     try:
-      interpreter = tf.lite.Interpreter(model_content=model_content)
+      interpreter = tfl_interpreter.Interpreter(model_content=model_content)
       interpreter.get_signature_list()
     except:
       return None
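Because the default builder now constructs a LiteRT interpreter, a custom builder passed to set_interpreter_builder should do the same. A minimal sketch, assuming `edge_model` is an existing converted TfLiteModel and accessing its private `_tflite_model` attribute the same way the tests above do (the extra keyword mirrors the new default builder):

from ai_edge_litert import interpreter as tfl_interpreter

# `edge_model` is assumed to be an ai_edge_torch TfLiteModel instance.
edge_model.set_interpreter_builder(
    lambda: tfl_interpreter.Interpreter(
        model_content=edge_model._tflite_model,
        experimental_default_delegate_latest_features=True,
    )
)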
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240919"
+__version__ = "0.3.0.dev20240921"
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240919
+Version: 0.3.0.dev20240921
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD RENAMED
@@ -2,8 +2,8 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
-ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
+ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
+ai_edge_torch/version.py,sha256=t9zajdsiowClI2fG0RkKVonPF-SUx9UBuUDOEZFU9y4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -25,7 +25,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=FSufFZEeTLBpUnzE1Iy-LvNN0mhDynWMNg7Mei8RpLQ,14973
+ai_edge_torch/_convert/test/test_convert.py,sha256=40QRxQFNeSRr4dLXJkzG-wKUlvJtsfv62cdvRrmBv5w,15097
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -39,22 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=KsM6VlzluTqbodG24IFr3biPxBrLay0z0gmnG0bcU2U,9277
+ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=B14IR4mIw6qBVUbiIRdfdUzHMCIJCJ0RFPsYOxA46qc,1776
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax1m4WTAnC3gMxBpcvGA7-xTO1Iuw,1802
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
-ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
+ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -78,16 +81,16 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
-ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LUChL5tA7FHL_DlTg5QKvGInmH9AwVVw9a-omcndiz8,2095
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
 ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
-ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
+ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
 ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
 ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
@@ -98,7 +101,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
+ai_edge_torch/generative/quantize/example.py,sha256=tlACaRsz6lqOxakzpXVFJZYfFKOiFqetcYVJqWVRdPE,1542
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -107,8 +110,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -116,7 +119,7 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
-ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
+ai_edge_torch/generative/utilities/verifier.py,sha256=vU9KgmFS7I9jNS_3H2SWROx-rbNqtMKgQC2MRhdqQ4g,8803
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -163,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/METADATA,sha256=SWy7BhOQDe0_SBF17deNndzt1bEYy7iXUxy0KznIPYM,1859
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,