ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240920__py3-none-any.whl

Files changed (32)
  1. ai_edge_torch/_convert/test/test_convert.py +7 -3
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +6 -4
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
  5. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  6. ai_edge_torch/generative/examples/openelm/verify.py +63 -0
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
  8. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  9. ai_edge_torch/generative/examples/phi/verify.py +63 -0
  10. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
  11. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  12. ai_edge_torch/generative/examples/smollm/verify.py +60 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +62 -0
  16. ai_edge_torch/generative/layers/builder.py +3 -1
  17. ai_edge_torch/generative/layers/model_config.py +3 -0
  18. ai_edge_torch/generative/layers/normalization.py +31 -20
  19. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  20. ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
  21. ai_edge_torch/generative/layers/unet/model_config.py +1 -0
  22. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +2 -2
  24. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
  25. ai_edge_torch/generative/utilities/verifier.py +249 -0
  26. ai_edge_torch/model.py +7 -4
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/RECORD +32 -27
  30. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/top_level.txt +0 -0

ai_edge_torch/_convert/test/test_convert.py
@@ -23,12 +23,12 @@ from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision
 
 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
 
 
 @dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])
 
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
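
The test changes above swap tf.lite.Interpreter for the standalone LiteRT interpreter. Below is a minimal sketch of the same call pattern outside the test suite, assuming the ai_edge_litert package is installed; the model path, signature input name ("x"), and input shape are placeholders, not part of this diff.

import numpy as np
from ai_edge_litert import interpreter as tfl_interpreter

# Load a converted model from disk (placeholder path) and run its
# "serving_default" signature with a dummy input.
with open("/tmp/model.tflite", "rb") as f:
  tflite_model = f.read()

interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
runner = interpreter.get_signature_runner("serving_default")
outputs = runner(x=np.zeros((1, 10), dtype=np.float32))
print(outputs.keys())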

ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
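
With this change the --tflite_path flag names only the output directory, and the file name encodes the quantization mode, prefill length, and KV-cache size. The snippet below merely re-derives the resulting path from the new flag defaults, with quantization assumed enabled; the other convert_to_tflite scripts below follow the same naming pattern.

import os

# Assumed settings: quantization on, plus the new flag defaults from this diff.
quantize = True
prefill_seq_len = 1024
kv_cache_max_len = 1280

quant_suffix = 'q8' if quantize else 'f32'
output_filename = f'gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
print(os.path.join('/tmp/', output_filename))  # /tmp/gemma2_q8_seq1024_ekv1280.tflite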

ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )

ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = openelm.build_model(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )

ai_edge_torch/generative/examples/openelm/openelm.py
@@ -15,16 +15,12 @@
 
 """Example of building an OpenELM model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -210,28 +206,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs an OpenELM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  openelm_goldens = torch.load(current_dir / "openelm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      openelm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/openelm"
-  )
-  define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/openelm/verify.py
@@ -0,0 +1,63 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored OpenELM-3B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "apple/OpenELM-3B"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
+  )
+
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = openelm.build_model(reauthored_checkpoint)
+
+  tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
+  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
      prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )

ai_edge_torch/generative/examples/phi/phi2.py
@@ -15,16 +15,12 @@
 
 """Example of building a Phi-2 model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -143,7 +139,10 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       intermediate_size=10240,
       use_bias=True,
   )
-  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.LAYER_NORM,
+      use_input_shape=False,  # Phi-2 does layer-norm with the weight shape.
+  )
   block_config = cfg.TransformerBlockConfig(
       attn_config=attn_config,
       ff_config=ff_config,
@@ -182,29 +181,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a Phi-2 model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      phi2_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/phi2"
-  )
-  define_and_run(input_checkpoint_path)
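
The Phi-2 config change above adds a use_input_shape flag to NormalizationConfig, with the comment that Phi-2 does layer-norm with the weight shape. The flag's actual implementation lives in the normalization-layer updates not shown in this diff, so the sketch below is only a hedged illustration, in plain PyTorch, of the two normalized-shape choices that comment appears to contrast.

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 2560)  # (batch, sequence, embedding)
weight = torch.ones(2560)
bias = torch.zeros(2560)

# use_input_shape=False (Phi-2): normalize over the weight's shape, i.e. the
# embedding dimension only.
y_weight_shape = F.layer_norm(x, weight.shape, weight, bias)

# use_input_shape=True: normalize over the full per-example input shape
# (sequence and embedding dimensions together).
y_input_shape = F.layer_norm(
    x, x.shape[1:], torch.ones(x.shape[1:]), torch.zeros(x.shape[1:])
)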

ai_edge_torch/generative/examples/phi/verify.py
@@ -0,0 +1,63 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Phi-2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Instruct: Write an email about the weather Output:",
+    "The input prompts to generate answers.",
+)
+
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+def main(_):
+  checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
+  verifier.log_msg("Loading the original model from", checkpoint)
+  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+      hf_generation_config=generation_config,
+  )
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = phi2.build_model(checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-03,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
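
Unlike the OpenELM and SmolLM verifiers, the Phi-2 script above resolves its checkpoint through kagglehub rather than the Hugging Face hub. A minimal sketch of that resolution step, using the model handle from the script; downloading may require Kaggle credentials configured for kagglehub.

import kagglehub

# Downloads (or reuses a cached copy of) the Phi-2 checkpoint and returns the
# local directory that phi2.build_model() and the tokenizer load from.
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
print(checkpoint)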

ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )

ai_edge_torch/generative/examples/smollm/smollm.py
@@ -16,15 +16,10 @@
 """Example of building a SmolLM model."""
 
 import copy
-import os
-import pathlib
 
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
 from torch import nn
 
 TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
@@ -104,28 +99,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a SmolLM model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  smollm_goldens = torch.load(current_dir / "smollm_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      smollm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/smollm"
-  )
-  define_and_run(input_checkpoint_path)

ai_edge_torch/generative/examples/smollm/verify.py
@@ -0,0 +1,60 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored SmolLM-135M model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "HuggingFaceTB/SmolLM-135M"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+  )
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = smollm.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-    512,
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-    1024,
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
@@ -15,16 +15,12 @@
 
 """Example of building a TinyLlama model."""
 
-import os
-import pathlib
-
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn
 
@@ -179,28 +175,3 @@ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model)
   model.eval()
   return model
-
-
-def define_and_run(checkpoint_path: str) -> None:
-  """Instantiates and runs a TinyLlama model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
-  kv_cache_max_len = 1024
-  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  assert torch.allclose(
-      tiny_llama_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/tiny_llama"
-  )
-  define_and_run(input_checkpoint_path)