ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240920__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. ai_edge_torch/_convert/test/test_convert.py +7 -3
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +6 -4
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
  5. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  6. ai_edge_torch/generative/examples/openelm/verify.py +63 -0
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
  8. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  9. ai_edge_torch/generative/examples/phi/verify.py +63 -0
  10. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
  11. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  12. ai_edge_torch/generative/examples/smollm/verify.py +60 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +62 -0
  16. ai_edge_torch/generative/layers/builder.py +3 -1
  17. ai_edge_torch/generative/layers/model_config.py +3 -0
  18. ai_edge_torch/generative/layers/normalization.py +31 -20
  19. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  20. ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
  21. ai_edge_torch/generative/layers/unet/model_config.py +1 -0
  22. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +2 -2
  24. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
  25. ai_edge_torch/generative/utilities/verifier.py +249 -0
  26. ai_edge_torch/model.py +7 -4
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/RECORD +32 -27
  30. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/top_level.txt +0 -0

ai_edge_torch/_convert/test/test_convert.py
@@ -23,12 +23,12 @@ from ai_edge_torch import config
  from ai_edge_torch._convert import conversion_utils
  from ai_edge_torch.testing import model_coverage
  import numpy as np
- import tensorflow as tf
  import torch
  from torch import nn
  import torchvision

  from absl.testing import absltest as googletest
+ from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import


  @dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
      np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
      np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

-     interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+     interpreter = tfl_interpreter.Interpreter(
+         model_content=edge_model._tflite_model
+     )
      runner = interpreter.get_signature_runner("serving_default")
      output_details = runner.get_output_details()
      self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
    def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
      model.eval()
      edge_model = ai_edge_torch.convert(model, args, kwargs)
-     interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+     interpreter = tfl_interpreter.Interpreter(
+         model_content=edge_model._tflite_model
+     )
      runner = interpreter.get_signature_runner("serving_default")
      input_details = runner.get_input_details()
      self.assertEqual(input_details.keys(), flat_inputs.keys())
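
A minimal sketch of the signature-runner pattern used in the updated tests, assuming the ai_edge_litert interpreter mirrors the tf.lite.Interpreter API; the helper name and the tflite_bytes argument are illustrative, not part of the package:

from ai_edge_litert import interpreter as tfl_interpreter


def run_serving_default(tflite_bytes: bytes, **inputs):
  """Runs the 'serving_default' signature of a serialized TFLite model."""
  interpreter = tfl_interpreter.Interpreter(model_content=tflite_bytes)
  runner = interpreter.get_signature_runner("serving_default")
  # The runner accepts named input tensors and returns a dict of output arrays.
  return runner(**inputs)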

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  )
  _TFLITE_PATH = flags.DEFINE_string(
      'tflite_path',
-     '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+     '/tmp/',
      'The tflite file path to export.',
  )
  _PREFILL_SEQ_LEN = flags.DEFINE_integer(
      'prefill_seq_len',
-     512,
+     1024,
      'The maximum size of prefill input tensor.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
      'kv_cache_max_len',
-     1024,
+     1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
    pytorch_model = gemma2.build_2b_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
    converter.convert_to_tflite(
        pytorch_model,
-       tflite_path=_TFLITE_PATH.value,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
        prefill_seq_len=_PREFILL_SEQ_LEN.value,
        quantize=_QUANTIZE.value,
    )
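
The converter scripts now derive the output file name from the flag values rather than taking a fixed path. A small sketch of that assembly with the new defaults; quantize=True is assumed here purely for illustration, since the _QUANTIZE default is not shown in this hunk:

import os

tflite_path, prefill_seq_len, kv_cache_max_len, quantize = '/tmp/', 1024, 1280, True
quant_suffix = 'q8' if quantize else 'f32'
output_filename = f'gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
print(os.path.join(tflite_path, output_filename))  # /tmp/gemma2_q8_seq1024_ekv1280.tflite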

ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  )
  _TFLITE_PATH = flags.DEFINE_string(
      'tflite_path',
-     '/tmp/gemma_q8_seq512_ekv1024.tflite',
+     '/tmp/',
      'The tflite file path to export.',
  )
  _PREFILL_SEQ_LEN = flags.DEFINE_integer(
      'prefill_seq_len',
-     512,
+     1024,
      'The maximum size of prefill input tensor.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
      'kv_cache_max_len',
-     1024,
+     1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
    pytorch_model = gemma.build_2b_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
    converter.convert_to_tflite(
        pytorch_model,
-       tflite_path=_TFLITE_PATH.value,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
        prefill_seq_len=_PREFILL_SEQ_LEN.value,
        quantize=_QUANTIZE.value,
    )

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  )
  _TFLITE_PATH = flags.DEFINE_string(
      'tflite_path',
-     '/tmp/openelm_q8_seq512_ekv1024.tflite',
+     '/tmp/',
      'The tflite file path to export.',
  )
  _PREFILL_SEQ_LEN = flags.DEFINE_integer(
      'prefill_seq_len',
-     512,
+     1024,
      'The maximum size of prefill input tensor.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
      'kv_cache_max_len',
-     1024,
+     1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
    pytorch_model = openelm.build_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
    converter.convert_to_tflite(
        pytorch_model,
-       tflite_path=_TFLITE_PATH.value,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
        prefill_seq_len=_PREFILL_SEQ_LEN.value,
        quantize=_QUANTIZE.value,
    )

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@

  """Example of building an OpenELM model."""

- import os
- import pathlib
-
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
  import torch
  from torch import nn

@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    loader.load(model, strict=False)
    model.eval()
    return model
-
-
- def define_and_run(checkpoint_path: str) -> None:
-   """Instantiates and runs an OpenELM model."""
-
-   current_dir = pathlib.Path(__file__).parent.resolve()
-   openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-   kv_cache_max_len = 1024
-   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(model.config)
-   output = model.forward(tokens, input_pos, kv)
-   assert torch.allclose(
-       openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-   )
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(
-       pathlib.Path.home(), "Downloads/llm_data/openelm"
-   )
-   define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/openelm/verify.py
@@ -0,0 +1,63 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored OpenELM-3B model."""
+
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.openelm import openelm
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+
+
+ def main(_):
+   checkpoint = "apple/OpenELM-3B"
+   verifier.log_msg("Loading the original model from", checkpoint)
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(
+           checkpoint, trust_remote_code=True
+       ),
+   )
+
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+   reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+   verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       prompts=_PROMPTS.value,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
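
The verification scripts reuse the Hugging Face cache as the local checkpoint directory for the reauthored build_model(), which expects a path rather than a repo id. A sketch of that lookup in isolation; local_checkpoint_dir is an illustrative helper name, not part of the package:

import pathlib

import transformers


def local_checkpoint_dir(repo_id: str) -> pathlib.Path:
  """Returns the local Hugging Face cache directory holding repo_id's files."""
  config_file = transformers.utils.cached_file(repo_id, transformers.utils.CONFIG_NAME)
  return pathlib.Path(config_file).parent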

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  )
  _TFLITE_PATH = flags.DEFINE_string(
      'tflite_path',
-     '/tmp/phi2_q8_seq512_ekv1024.tflite',
+     '/tmp/',
      'The tflite file path to export.',
  )
  _PREFILL_SEQ_LEN = flags.DEFINE_integer(
      'prefill_seq_len',
-     512,
+     1024,
      'The maximum size of prefill input tensor.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
      'kv_cache_max_len',
-     1024,
+     1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
    pytorch_model = phi2.build_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
    converter.convert_to_tflite(
        pytorch_model,
-       tflite_path=_TFLITE_PATH.value,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
        prefill_seq_len=_PREFILL_SEQ_LEN.value,
        quantize=_QUANTIZE.value,
    )

ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@

  """Example of building a Phi-2 model."""

- import os
- import pathlib
-
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
  import torch
  from torch import nn

@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
      intermediate_size=10240,
      use_bias=True,
  )
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+ norm_config = cfg.NormalizationConfig(
+     type=cfg.NormalizationType.LAYER_NORM,
+     use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+ )
  block_config = cfg.TransformerBlockConfig(
      attn_config=attn_config,
      ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    loader.load(model)
    model.eval()
    return model
-
-
- def define_and_run(checkpoint_path: str) -> None:
-   """Instantiates and runs a Phi-2 model."""
-
-   current_dir = pathlib.Path(__file__).parent.resolve()
-   phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-   kv_cache_max_len = 1024
-   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(model.config)
-   output = model.forward(tokens, input_pos, kv)
-   print("comparing with goldens..")
-   assert torch.allclose(
-       phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-   )
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(
-       pathlib.Path.home(), "Downloads/llm_data/phi2"
-   )
-   define_and_run(input_checkpoint_path)
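
The use_input_shape flag added to Phi-2's NormalizationConfig above selects which dimensions the layer norm is computed over. A sketch of the distinction it encodes, assuming the normalization layer ultimately calls torch.nn.functional.layer_norm (the actual wiring lives in normalization.py, which is not shown in this hunk):

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 2560)  # (batch, sequence, embedding)

# use_input_shape=True: normalize over every non-batch dimension of the input.
y_over_input = F.layer_norm(x, x.shape[1:])

# use_input_shape=False (Phi-2 style): normalize only over the trailing
# dimension, i.e. the shape of the layer-norm weight.
y_over_weight = F.layer_norm(x, (2560,))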

ai_edge_torch/generative/examples/phi/verify.py
@@ -0,0 +1,63 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored Phi-2 model."""
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.utilities import verifier
+ import kagglehub
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "Instruct: Write an email about the weather Output:",
+     "The input prompts to generate answers.",
+ )
+
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
+
+ def main(_):
+   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+   verifier.log_msg("Loading the original model from", checkpoint)
+   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+       hf_generation_config=generation_config,
+   )
+
+   verifier.log_msg("Building the reauthored model from", checkpoint)
+   reauthored_model = phi2.build_model(checkpoint)
+
+   verifier.log_msg("Loading the tokenizer from", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       prompts=_PROMPTS.value,
+       atol=1e-03,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
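
The atol argument passed to verify_reauthored_model bounds how far the reauthored model's logits may drift from the original model's. A rough sketch of that tolerance semantics; this illustrates the check, not the actual verifier.py implementation, which is not shown in this diff:

import torch


def logits_close(
    original_logits: torch.Tensor,
    reauthored_logits: torch.Tensor,
    atol: float = 1e-03,
) -> bool:
  """True when every reauthored logit is within atol of the original."""
  return torch.allclose(original_logits, reauthored_logits, atol=atol)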

ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  )
  _TFLITE_PATH = flags.DEFINE_string(
      'tflite_path',
-     '/tmp/smollm_q8_seq512_ekv1024.tflite',
+     '/tmp/',
      'The tflite file path to export.',
  )
  _PREFILL_SEQ_LEN = flags.DEFINE_integer(
      'prefill_seq_len',
-     512,
+     1024,
      'The maximum size of prefill input tensor.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
      'kv_cache_max_len',
-     1024,
+     1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
    pytorch_model = smollm.build_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
    converter.convert_to_tflite(
        pytorch_model,
-       tflite_path=_TFLITE_PATH.value,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
        prefill_seq_len=_PREFILL_SEQ_LEN.value,
        quantize=_QUANTIZE.value,
    )

ai_edge_torch/generative/examples/smollm/smollm.py
@@ -16,15 +16,10 @@
  """Example of building a SmolLM model."""

  import copy
- import os
- import pathlib

  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
- from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
- import torch
  from torch import nn

  TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    loader.load(model, strict=False)
    model.eval()
    return model
-
-
- def define_and_run(checkpoint_path: str) -> None:
-   """Instantiates and runs a SmolLM model."""
-
-   current_dir = pathlib.Path(__file__).parent.resolve()
-   smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-   kv_cache_max_len = 1024
-   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(model.config)
-   output = model.forward(tokens, input_pos, kv)
-   assert torch.allclose(
-       smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-   )
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(
-       pathlib.Path.home(), "Downloads/llm_data/smollm"
-   )
-   define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/smollm/verify.py
@@ -0,0 +1,60 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored SmolLM-135M model."""
+
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "What is the meaning of life?",
+     "The input prompts to generate answers.",
+ )
+
+
+ def main(_):
+   checkpoint = "HuggingFaceTB/SmolLM-135M"
+   verifier.log_msg("Loading the original model from", checkpoint)
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+   )
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+   reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+   verifier.log_msg("Loading the tokenizer from", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       prompts=_PROMPTS.value,
+       atol=1e-04,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
  )
  _TFLITE_PATH = flags.DEFINE_string(
      'tflite_path',
-     '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+     '/tmp/',
      'The tflite file path to export.',
  )
  _PREFILL_SEQ_LEN = flags.DEFINE_integer(
      'prefill_seq_len',
-     512,
+     1024,
      'The maximum size of prefill input tensor.',
  )
  _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
      'kv_cache_max_len',
-     1024,
+     1280,
      'The maximum size of KV cache buffer, including both prefill and decode.',
  )
  _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
    pytorch_model = tiny_llama.build_model(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
+   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+   output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
    converter.convert_to_tflite(
        pytorch_model,
-       tflite_path=_TFLITE_PATH.value,
+       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
        prefill_seq_len=_PREFILL_SEQ_LEN.value,
        quantize=_QUANTIZE.value,
    )

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,16 +15,12 @@

  """Example of building a TinyLlama model."""

- import os
- import pathlib
-
  from ai_edge_torch.generative.layers import attention
  from ai_edge_torch.generative.layers import builder
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
- import numpy as np
  import torch
  from torch import nn

@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
    loader.load(model)
    model.eval()
    return model
-
-
- def define_and_run(checkpoint_path: str) -> None:
-   """Instantiates and runs a TinyLlama model."""
-
-   current_dir = pathlib.Path(__file__).parent.resolve()
-   tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-   kv_cache_max_len = 1024
-   model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-   tokens[0, :4] = idx
-   input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-   kv = kv_utils.KVCache.from_model_config(model.config)
-   output = model.forward(tokens, input_pos, kv)
-   assert torch.allclose(
-       tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-   )
-
-
- if __name__ == "__main__":
-   input_checkpoint_path = os.path.join(
-       pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-   )
-   define_and_run(input_checkpoint_path)