ai-edge-torch-nightly 0.3.0.dev20240919__py3-none-any.whl → 0.3.0.dev20240921__py3-none-any.whl

Files changed (28)
  1. ai_edge_torch/_convert/test/test_convert.py +7 -3
  2. ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} +9 -7
  3. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
  4. ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +3 -36
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -26
  6. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +55 -0
  7. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +55 -0
  8. ai_edge_torch/generative/examples/gemma/verify_util.py +142 -0
  9. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
  10. ai_edge_torch/generative/examples/openelm/verify.py +6 -4
  11. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
  12. ai_edge_torch/generative/examples/phi/verify.py +14 -4
  13. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
  14. ai_edge_torch/generative/examples/smollm/verify.py +5 -4
  15. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
  16. ai_edge_torch/generative/examples/tiny_llama/verify.py +6 -5
  17. ai_edge_torch/generative/layers/feed_forward.py +0 -1
  18. ai_edge_torch/generative/quantize/example.py +3 -3
  19. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  20. ai_edge_torch/generative/test/test_model_conversion_large.py +5 -5
  21. ai_edge_torch/generative/utilities/verifier.py +77 -26
  22. ai_edge_torch/model.py +7 -4
  23. ai_edge_torch/version.py +1 -1
  24. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA +1 -1
  25. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD +28 -25
  26. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/LICENSE +0 -0
  27. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/WHEEL +0 -0
  28. {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -23,12 +23,12 @@ from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision

 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import


 @dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
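Note: `ai_edge_litert`'s `Interpreter` is used here as a drop-in replacement for `tf.lite.Interpreter`. A minimal sketch of the migrated usage pattern, assuming `tflite_bytes` holds a serialized TFLite flatbuffer:

```python
# Minimal sketch of the LiteRT interpreter usage that replaces tf.lite.Interpreter.
from ai_edge_litert import interpreter as tfl_interpreter

def run_serving_default(tflite_bytes: bytes, **inputs):
  interp = tfl_interpreter.Interpreter(model_content=tflite_bytes)
  runner = interp.get_signature_runner("serving_default")
  return runner(**inputs)  # dict of output name -> numpy array
```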
ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} RENAMED
@@ -13,14 +13,14 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of converting a Gemma model to multi-signature tflite model."""
+"""Example of converting a Gemma1 model to multi-signature tflite model."""

 import os
 import pathlib

 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter

 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -51,12 +51,14 @@ _QUANTIZE = flags.DEFINE_bool(


 def main(_):
-  pytorch_model = gemma.build_2b_model(
+  pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
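With this change the `tflite_path` flag names a directory and the filename is derived from the quantization mode and sequence sizes. A small sketch of the naming rule (the prefill/KV defaults come from the flags above; the `quantize` default is not shown in this hunk):

```python
# Sketch of the derived output filename (not part of the package API).
def output_name(model: str, quantize: bool, seq: int, ekv: int) -> str:
  quant_suffix = 'q8' if quantize else 'f32'
  return f'{model}_{quant_suffix}_seq{seq}_ekv{ekv}.tflite'

# e.g. with the new flag defaults prefill_seq_len=1024, kv_cache_max_len=1280:
assert output_name('gemma', True, 1024, 1280) == 'gemma_q8_seq1024_ekv1280.tflite'
```

The same scheme is applied to the gemma2, openelm, phi2, smollm, and tiny_llama converters below.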
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} RENAMED
@@ -13,10 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of building a Gemma model."""
-
-import os
-import pathlib
+"""Example of building a Gemma1 model."""

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -24,7 +21,6 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn

@@ -32,13 +28,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
     post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
+    embedding="embedder",
     final_norm="model.norm",
     lm_head=None,
 )
@@ -192,30 +186,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-  """Instantiates and runs a Gemma 2B model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
-
-  kv_cache_max_len = 1024
-  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
-  )
-  define_and_run_2b(input_checkpoint_path)
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -267,29 +267,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-  """Instantiates and runs a Gemma2 2B model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
-  print("Running GEMMA 2")
-  kv_cache_max_len = 1024
-  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  toks = torch.from_numpy(
-      np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
-  )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  out = model.forward(tokens, input_pos, kv)
-  out_final = out["logits"][0, 8, :]
-  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
-
-
-if __name__ == "__main__":
-  torch.set_printoptions(sci_mode=True)
-  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
-  define_and_run_2b(path)
ai_edge_torch/generative/examples/gemma/verify_gemma1.py ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma1 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma1
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = gemma1.build_2b_model(checkpoint)
+
+  verify_util.verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename="gemma-2b-it.ckpt",
+      generate_prompts=_PROMPTS.value,
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/gemma/verify_gemma2.py ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = gemma2.build_2b_model(checkpoint)
+
+  verify_util.verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=_PROMPTS.value,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/gemma/verify_util.py ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions to verify the reauthored Gemma model."""
+
+import dataclasses
+import os
+from typing import List, Tuple
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import verifier
+from gemma import config as gemma_config
+from gemma import model as gemma_model
+import torch
+
+
+@dataclasses.dataclass
+class _Output:
+  logits: torch.Tensor
+
+
+class GemmaWrapper(verifier.ModelWrapper):
+  """Gemma model wrapper for verification.
+
+  Verifier calls model.forward() with the maximum sequence length (1024),
+  expecting the output to have a 'logits' field, while Gemma takes the input
+  tokens with the actual length and returns logits in a tuple.
+
+  Verifier runs the tokenizer before model.generate() while Gemma runs the
+  tokenizer inside model.generate().
+  """
+
+  def __init__(self, model: torch.nn.Module, max_new_tokens: int):
+    super().__init__(model)
+    self.max_new_tokens = max_new_tokens
+
+  def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
+    for i in range(tokens.shape[1]):
+      if tokens[0, i] == 0:
+        return i
+    return tokens.shape[1]
+
+  def _get_kv_caches(
+      self, max_seq_len: int
+  ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+    config = self.model.config
+    cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+    cache = torch.zeros(cache_size)
+    return [
+        (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+    ]
+
+  def forward(self, tokens: torch.Tensor) -> _Output:
+    """Forwards the model after reducing input tokens to the actual length."""
+    actual_input_len = self._get_actual_input_len(tokens)
+    input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+    mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+    _, logits = self.model.forward(
+        input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+        input_positions=input_pos,
+        kv_write_indices=None,
+        kv_caches=self._get_kv_caches(tokens.shape[1]),
+        mask=mask_cache.index_select(2, input_pos),
+        output_positions=input_pos,
+        temperatures=None,
+        top_ps=torch.tensor([1.0], dtype=torch.float),
+        top_ks=torch.tensor([1], dtype=torch.long),
+    )
+    return _Output(logits.float())
+
+  def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Generates the response after decoding the tokens into a string."""
+    prompts = self.model.tokenizer.decode(tokens[0].tolist())
+    response = self.model.generate(
+        prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+    )
+    return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+class TokenizerWrapper(torch.nn.Module):
+  """Tokenizer wrapper for verification.
+
+  Verifier expects the tokenizer to handle tokens in torch.Tensor while the
+  Gemma tokenizer expects tokens in a list.
+  """
+
+  def __init__(self, tokenizer: torch.nn.Module):
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def encode(self, text: str, **_) -> torch.Tensor:
+    """Adds one more dimension to the output of the tokenizer."""
+    return torch.tensor([self.tokenizer.encode(text)])
+
+  def decode(self, tokens: torch.Tensor) -> str:
+    """Decodes the token sequence after converting to a list."""
+    return self.tokenizer.decode(tokens.tolist())
+
+
+def verify_reauthored_gemma_model(
+    checkpoint: str,
+    variant: str,
+    reauthored_model: torch.nn.Module,
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]],
+    weight_filename: str = "model.ckpt",
+    tokenizer_filename: str = "tokenizer.model",
+    max_new_tokens: int = 20,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored Gemma model against the original model."""
+  config = gemma_config.get_model_config(variant)
+  config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+  # Use float32 to be compatible with the reauthored model.
+  config.dtype = torch.float32
+
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = gemma_model.GemmaForCausalLM(config).eval()
+  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+  verifier.verify_reauthored_model(
+      original_model=GemmaWrapper(original_model, max_new_tokens),
+      reauthored_model=reauthored_model,
+      tokenizer=TokenizerWrapper(original_model.tokenizer),
+      generate_prompts=generate_prompts,
+      forward_input_ids=forward_input_ids,
+      rtol=rtol,
+      atol=atol,
+  )
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/openelm/verify.py CHANGED
@@ -33,8 +33,10 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "apple/OpenELM-3B"
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
   )

   # Locate the cached dir.
@@ -50,10 +52,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
   )

ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -24,15 +24,25 @@ import transformers

 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
-    "What is the meaning of life?",
+    "Instruct: Write an email about the weather Output:",
     "The input prompts to generate answers.",
 )

+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)

 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+      hf_generation_config=generation_config,
+  )

   verifier.log_msg("Building the reauthored model from", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
@@ -41,10 +51,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
       atol=1e-03,
   )

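Pinning the generation length through a `GenerationConfig`, as done above, keeps the original and reauthored runs comparable. A condensed sketch of the pattern (the checkpoint path is a placeholder for any HuggingFace model directory):

```python
# Sketch: capping HF generation length via GenerationConfig.
import transformers

def load_capped_config(checkpoint: str, max_new_tokens: int):
  config = transformers.GenerationConfig.from_pretrained(checkpoint)
  config.max_new_tokens = max_new_tokens  # override the checkpoint default
  return config
```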
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -33,8 +33,9 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
-
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+  )
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -47,10 +48,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
       atol=1e-04,
   )

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/tiny_llama/verify.py CHANGED
@@ -33,10 +33,11 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
   verifier.log_msg("Loading the original model from", checkpoint)
-  original_model = transformers.AutoModelForCausalLM.from_pretrained(
-      checkpoint, trust_remote_code=True
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
   )
-
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -49,10 +50,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=original_model,
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-      prompts=_PROMPTS.value,
+      generate_prompts=_PROMPTS.value,
       atol=1e-04,
   )

ai_edge_torch/generative/layers/feed_forward.py CHANGED
@@ -18,7 +18,6 @@ from typing import Callable, Optional

 import torch
 from torch import nn
-import torch.nn.functional as F


 class SequentialFeedForward(nn.Module):
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -14,7 +14,7 @@
 # ==============================================================================

 import ai_edge_torch
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.quantize import quant_recipes
 import numpy as np
 import torch
@@ -22,8 +22,8 @@ import torch

 def main():
   # Build a PyTorch model as usual
-  config = gemma.get_fake_model_config()
-  model = gemma.Gemma(config)
+  config = gemma1.get_fake_model_config()
+  model = gemma1.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
 import torch

 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter


 class TestModelConversion(googletest.TestCase):
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -17,7 +17,7 @@

 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import gemma
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
@@ -28,7 +28,7 @@ import numpy as np
 import torch

 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter


 class TestModelConversion(googletest.TestCase):
@@ -82,9 +82,9 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def test_gemma(self):
-    config = gemma.get_fake_model_config()
-    pytorch_model = gemma.Gemma(config).eval()
+  def test_gemma1(self):
+    config = gemma1.get_fake_model_config()
+    pytorch_model = gemma1.Gemma(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,17 +16,65 @@
 """Common utility functions to verify the reauthored models."""

 import datetime
-from typing import List
+from typing import List, Optional, Union

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import numpy as np
 import torch
+import transformers


 def log_msg(*args):
   print("[%s]" % datetime.datetime.now(), *args)


+class ModelWrapper(torch.nn.Module):
+  """A wrapper for the model to be verified.
+
+  This could be a HuggingFace model or a regular PyTorch model.
+  """
+
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      model_format: str = "huggingface",
+      hf_generation_config: Optional[transformers.GenerationConfig] = None,
+  ):
+    """Initializes the wrapper.
+
+    Args:
+      model (torch.nn.Module): The original model. This could be a model built
+        from HuggingFace transformers, or a regular PyTorch model.
+      model_format (str): The format of the model. It should be either
+        "huggingface" or "pytorch".
+      hf_generation_config (transformers.GenerationConfig): The HuggingFace
+        generation config. This config will only be used if the underlying
+        model is built from HuggingFace transformers.
+    """
+    super().__init__()
+    self.model = model
+    self.model_format = model_format
+    self.hf_generation_config = hf_generation_config
+
+  def generate(
+      self, inputs: torch.Tensor
+  ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
+    if self.model_format == "huggingface":
+      return self.model.generate(
+          inputs=inputs, generation_config=self.hf_generation_config
+      )
+    else:
+      raise NotImplementedError(
+          "generate() is not implemented for model format: %s"
+          % self.model_format
+      )
+
+  def forward(
+      self,
+      inputs: torch.Tensor,
+  ):
+    return self.model.forward(inputs)
+
+
 def forward(
     model: torch.nn.Module,
     tokens: torch.Tensor,
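For models that are not built with HuggingFace transformers, the intended extension point is to subclass `ModelWrapper` and override `forward()`/`generate()`, as `verify_util.GemmaWrapper` does above. A minimal hypothetical sketch:

```python
# Hypothetical adapter for a plain PyTorch model (not part of the package).
import torch
from ai_edge_torch.generative.utilities import verifier

class MyWrapper(verifier.ModelWrapper):
  def __init__(self, model: torch.nn.Module):
    super().__init__(model, model_format="pytorch")

  def forward(self, inputs: torch.Tensor):
    # verify_with_input_ids() reads a `.logits` field off the result.
    return self.model(inputs)

  def generate(self, inputs: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError("prompt-based verification not supported here")
```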
@@ -75,9 +123,9 @@ def generate(


 def verify_with_input_ids(
-    original_model: torch.nn.Module,
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
-    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    input_ids: List[int],
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
@@ -87,10 +135,10 @@
   It compares only one output from the original and the reauthored model.

   Args:
-    original_model (torch.nn.Module): The original model.
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
-    input_ids (torch.Tensor): The input token IDs to forward.
+    input_ids (List[int]): The input token IDs to forward with.
     kv_cache_max_len (int): The maximum sequence length of the KV cache.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
@@ -99,18 +147,17 @@
     True if the model reauthored generates the same output of the original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  input_ids_len = input_ids.shape[1]
-  tokens[0, :input_ids_len] = input_ids
+  tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()

   log_msg("Forwarding the original model...")
   outputs_original = original_model.forward(tokens)
-  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
   log_msg("logits_original: ", logits_original)

   log_msg("Forwarding the reauthored model...")
   kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
   outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
   log_msg("logits_reauthored:", logits_reauthored)

   return torch.allclose(
@@ -119,7 +166,7 @@


 def verify_model_with_prompts(
-    original_model: torch.nn.Module,
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
     tokenizer: torch.nn.Module,
     prompts: str,
@@ -130,7 +177,7 @@
   original and the reauthored model.

   Args:
-    original_model (torch.nn.Module): The original model.
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
     tokenizer (torch.nn.Module): The tokenizer.
@@ -156,10 +203,11 @@


 def verify_reauthored_model(
-    original_model: torch.nn.Module,
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
     tokenizer: torch.nn.Module,
-    prompts: List[str],
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
 ):
@@ -174,26 +222,29 @@
   It prints out "PASS" or "FAILED" to the console.

   Args:
-    original_model (torch.nn.Module): The original model.
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
     tokenizer (torch.nn.Module): The tokenizer.
-    prompts (List[str]): List of the input prompts to generate answers.
+    generate_prompts (List[str]): List of the input prompts to generate answers.
+    forward_input_ids (List[List[int]]): List of the input token IDs to
+      forward with.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
   """
-  log_msg("Verifying the reauthored model with an arbitrary input...")
-  if verify_with_input_ids(
-      original_model, reauthored_model, rtol=rtol, atol=atol
-  ):
-    log_msg("PASS")
-  else:
-    log_msg("FAILED")
+  for input_ids in forward_input_ids:
+    log_msg("Verifying the reauthored model with input IDs:", input_ids)
+    if verify_with_input_ids(
+        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")

-  for p in prompts:
-    log_msg("Verifying the reauthored model with prompts:", p)
+  for prompts in generate_prompts:
+    log_msg("Verifying the reauthored model with prompts:", prompts)
     if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, p
+        original_model, reauthored_model, tokenizer, prompts
     ):
       log_msg("PASS")
     else:
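End to end, the renamed keyword arguments make verification calls look like this (a condensed sketch based on the example scripts; `reauthored_model` is assumed to be a model built with the ai_edge_torch Generative API):

```python
# Condensed sketch of the new verify_reauthored_model call shape.
import transformers
from ai_edge_torch.generative.utilities import verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"
wrapper_model = verifier.ModelWrapper(
    model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
verifier.verify_reauthored_model(
    original_model=wrapper_model,
    reauthored_model=reauthored_model,  # assumed: a reauthored Generative API model
    tokenizer=tokenizer,
    generate_prompts=["What is the meaning of life?"],
    forward_input_ids=[[1, 2, 3, 4]],  # this is also the default if omitted
)
```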
ai_edge_torch/model.py CHANGED
@@ -27,6 +27,8 @@ from typing import Callable
 import numpy.typing as npt
 import tensorflow as tf

+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
+
 DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY


@@ -65,7 +67,7 @@ class TfLiteModel(Model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
-    self._interpreter_builder = lambda: tf.lite.Interpreter(
+    self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
         model_content=self._tflite_model,
         experimental_default_delegate_latest_features=True,
    )
@@ -75,12 +77,13 @@ class TfLiteModel(Model):
     return self._tflite_model

   def set_interpreter_builder(
-      self, builder: Callable[[], tf.lite.Interpreter]
+      self, builder: Callable[[], tfl_interpreter.Interpreter]
   ) -> None:
     """Sets a custom interpreter builder.

     Args:
-      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+      builder: A function that returns a `tfl_interpreter.Interpreter` or its
+        subclass.
     """
     self._interpreter_builder = builder

@@ -166,7 +169,7 @@ class TfLiteModel(Model):

     # Check if this is indeed a tflite model:
     try:
-      interpreter = tf.lite.Interpreter(model_content=model_content)
+      interpreter = tfl_interpreter.Interpreter(model_content=model_content)
       interpreter.get_signature_list()
     except:
       return None
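Callers that supplied a custom `tf.lite.Interpreter` factory should now return the LiteRT interpreter type instead. A sketch, reusing the model bytes the way the tests above do (`num_threads` is assumed to be supported, as in the standard TFLite interpreter):

```python
# Sketch: supplying a custom interpreter builder after the LiteRT migration.
# Assumes `edge_model` is an ai_edge_torch TfLiteModel instance.
from ai_edge_litert import interpreter as tfl_interpreter

def my_builder() -> tfl_interpreter.Interpreter:
  return tfl_interpreter.Interpreter(
      model_content=edge_model._tflite_model,  # serialized flatbuffer bytes
      num_threads=4,  # assumed standard TFLite interpreter option
  )

edge_model.set_interpreter_builder(my_builder)
```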
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20240919"
+__version__ = "0.3.0.dev20240921"
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240919
+Version: 0.3.0.dev20240921
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD RENAMED
@@ -2,8 +2,8 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
-ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
+ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
+ai_edge_torch/version.py,sha256=t9zajdsiowClI2fG0RkKVonPF-SUx9UBuUDOEZFU9y4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -25,7 +25,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=FSufFZEeTLBpUnzE1Iy-LvNN0mhDynWMNg7Mei8RpLQ,14973
+ai_edge_torch/_convert/test/test_convert.py,sha256=40QRxQFNeSRr4dLXJkzG-wKUlvJtsfv62cdvRrmBv5w,15097
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -39,22 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=KsM6VlzluTqbodG24IFr3biPxBrLay0z0gmnG0bcU2U,9277
+ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=B14IR4mIw6qBVUbiIRdfdUzHMCIJCJ0RFPsYOxA46qc,1776
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax1m4WTAnC3gMxBpcvGA7-xTO1Iuw,1802
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
-ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
+ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -78,16 +81,16 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
-ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LUChL5tA7FHL_DlTg5QKvGInmH9AwVVw9a-omcndiz8,2095
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
 ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
-ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
+ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
 ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
 ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
@@ -98,7 +101,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
+ai_edge_torch/generative/quantize/example.py,sha256=tlACaRsz6lqOxakzpXVFJZYfFKOiFqetcYVJqWVRdPE,1542
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -107,8 +110,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -116,7 +119,7 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
-ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
+ai_edge_torch/generative/utilities/verifier.py,sha256=vU9KgmFS7I9jNS_3H2SWROx-rbNqtMKgQC2MRhdqQ4g,8803
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -163,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/METADATA,sha256=SWy7BhOQDe0_SBF17deNndzt1bEYy7iXUxy0KznIPYM,1859
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,