ai-edge-torch-nightly 0.3.0.dev20240925__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -2
  2. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +203 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  12. ai_edge_torch/generative/examples/phi/phi3.py +15 -21
  13. ai_edge_torch/generative/examples/phi/verify.py +13 -12
  14. ai_edge_torch/generative/examples/phi/verify_phi3.py +13 -12
  15. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  16. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  18. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  19. ai_edge_torch/generative/utilities/verifier.py +130 -114
  20. ai_edge_torch/version.py +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +25 -18
  23. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,28 +15,33 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
+import logging
 import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -44,18 +49,21 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = openelm.build_model(reauthored_checkpoint)
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
ai_edge_torch/generative/examples/phi/phi3.py
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def build_rope_cache(
+def _build_rope_cache(
     size: int,
     dim: int,
-    base: int = 10000,
-    condense_ratio: int = 1,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device = None,
-    theta_factors: torch.Tensor = None,
-    scale: float = 1.0,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
 
@@ -116,26 +116,20 @@ def build_rope_cache(
   Args:
     size (int): The size of the built cache.
     dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value. Defaults to 10000.
+    base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed. Defaults to 1.
-    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
-      torch.float32.
-    device (torch.device, optional): Output tensor's data type. Defaults to
-      None in which case "cpu" is used.
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values. Defaults to None.
-    scale (float, optional): A float used to scale the rope values. Defaults
-      to 1.0.
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  if device is None:
-    device = torch.device('cpu')
   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-  if theta_factors is not None:
-    theta = theta / theta_factors
+  theta = theta / theta_factors
   seq_idx = torch.arange(size) / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,7 +161,7 @@ class Phi3_5Mini(nn.Module):
         config.final_norm_config,
     )
     attn_config = block_config.attn_config
-    self.rope_cache = build_rope_cache(
+    self.rope_cache = _build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
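
For readers skimming the diff, the computation that the renamed `_build_rope_cache` performs (now with every argument required) can be reproduced standalone. The sketch below mirrors only the lines visible in the hunks above; the argument values are illustrative, and the `sin` line is assumed by symmetry with `cos`.

```python
# Standalone sketch of the RoPE cache math shown in the phi3.py hunks above.
# Values are illustrative only; theta_factors needs dim // 2 entries so it
# matches theta's length.
from typing import Tuple

import torch


def rope_cache_sketch(
    size: int,
    dim: int,
    base: int,
    condense_ratio: int,
    dtype: torch.dtype,
    device: torch.device,
    theta_factors: torch.Tensor,
    scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Per-channel inverse frequencies, rescaled by the per-channel factors.
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  theta = theta / theta_factors
  # One row of rotation angles per (condensed) sequence position.
  seq_idx = torch.arange(size) / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
  return cos, sin


cos, sin = rope_cache_sketch(
    size=8,
    dim=4,
    base=10_000,
    condense_ratio=1,
    dtype=torch.float32,
    device=torch.device("cpu"),
    theta_factors=torch.ones(2),
    scale=1.0,
)
print(cos.shape, sin.shape)  # torch.Size([8, 2]) torch.Size([8, 2])
```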
ai_edge_torch/generative/examples/phi/verify.py
@@ -14,14 +14,17 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
+import logging
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -36,25 +39,23 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  verifier.log_msg("Loading the original model from", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
-  verifier.log_msg("Building the reauthored model from", checkpoint)
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 
ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -15,14 +15,17 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -37,30 +40,28 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = phi3.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,43 +15,53 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = smollm.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,45 +15,55 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
    "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -19,6 +19,7 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/utilities/transformers_verifier.py
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
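
Taken together with the updated verify scripts above, the new wrapper classes are wired up roughly as follows. This is a minimal sketch assembled from the diff, not a file in the package; the SmolLM checkpoint, prompt, and flag values are simply the examples the scripts use.

```python
# Minimal sketch of the updated verification flow, assembled from the verify
# scripts in this diff (SmolLM used as the example model).
import logging
import pathlib

from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier
import transformers

checkpoint = "HuggingFaceTB/SmolLM-135M"

logging.info("Loading the original model from: %s", checkpoint)
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

# The reauthored model is built from the locally cached checkpoint directory.
cached_config_file = transformers.utils.cached_file(
    checkpoint, transformers.utils.CONFIG_NAME
)
reauthored_model = smollm.build_model(pathlib.Path(cached_config_file).parent)

verifier.verify_reauthored_model(
    original_model=transformers_verifier.TransformersModelWrapper(
        original_model
    ),
    reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
    tokenizer=verifier.TokenizerWrapper(tokenizer),
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,
    atol=1e-04,
)
```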