ai-edge-torch-nightly 0.3.0.dev20240925__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl

Files changed (25)
  1. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -2
  2. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +203 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  12. ai_edge_torch/generative/examples/phi/phi3.py +15 -21
  13. ai_edge_torch/generative/examples/phi/verify.py +13 -12
  14. ai_edge_torch/generative/examples/phi/verify_phi3.py +13 -12
  15. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  16. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  18. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  19. ai_edge_torch/generative/utilities/verifier.py +130 -114
  20. ai_edge_torch/version.py +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +25 -18
  23. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/openelm/verify.py
@@ -15,28 +15,33 @@
 
 """Verifies the reauthored OpenELM-3B model."""
 
+import logging
 import pathlib
-
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "apple/OpenELM-3B"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
 
   # Locate the cached dir.
@@ -44,18 +49,21 @@ def main(_):
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = openelm.build_model(reauthored_checkpoint)
 
   tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
-  verifier.log_msg("Loading the tokenizer from", tokenizer_checkpoint)
+  logging.info("Loading the tokenizer from: %s", tokenizer_checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
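All of the verify scripts in this release converge on the same pattern: progress messages go through the standard logging module, the original HuggingFace model is wrapped in transformers_verifier.TransformersModelWrapper, and the reauthored model and tokenizer are wrapped before being handed to verifier.verify_reauthored_model together with an explicit max_new_tokens. Below is a minimal sketch of that shared flow, assuming a HuggingFace checkpoint name and a build_model helper like the ones in the examples; the verify_checkpoint function itself is hypothetical and not part of the package.

import logging

from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier
import transformers


def verify_checkpoint(checkpoint, build_model, prompts, max_new_tokens=30):
  """Hypothetical helper showing the shared verification flow."""
  logging.info("Loading the original model from: %s", checkpoint)
  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)

  logging.info("Building the reauthored model from: %s", checkpoint)
  reauthored_model = build_model(checkpoint)
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  # Wrap all three objects so verify_reauthored_model sees a uniform interface.
  verifier.verify_reauthored_model(
      original_model=transformers_verifier.TransformersModelWrapper(
          original_model
      ),
      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
      tokenizer=verifier.TokenizerWrapper(tokenizer),
      generate_prompts=prompts,
      max_new_tokens=max_new_tokens,
  )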
ai_edge_torch/generative/examples/phi/phi3.py
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def build_rope_cache(
+def _build_rope_cache(
     size: int,
     dim: int,
-    base: int = 10000,
-    condense_ratio: int = 1,
-    dtype: torch.dtype = torch.float32,
-    device: torch.device = None,
-    theta_factors: torch.Tensor = None,
-    scale: float = 1.0,
+    base: int,
+    condense_ratio: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    theta_factors: torch.Tensor,
+    scale: float,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
   """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
 
@@ -116,26 +116,20 @@ def build_rope_cache(
   Args:
     size (int): The size of the built cache.
     dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value. Defaults to 10000.
+    base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed. Defaults to 1.
-    dtype (torch.dtype, optional): Output tensor's data type. Defaults to
-      torch.float32.
-    device (torch.device, optional): Output tensor's data type. Defaults to
-      None in which case "cpu" is used.
+      condensed.
+    dtype (torch.dtype, optional): Output tensor's data type.
+    device (torch.device, optional): Output tensor's data type.
     theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values. Defaults to None.
-    scale (float, optional): A float used to scale the rope values. Defaults
-      to 1.0.
+      scale the theta values.
+    scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  if device is None:
-    device = torch.device('cpu')
   theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-  if theta_factors is not None:
-    theta = theta / theta_factors
+  theta = theta / theta_factors
   seq_idx = torch.arange(size) / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,7 +161,7 @@ class Phi3_5Mini(nn.Module):
         config.final_norm_config,
     )
     attn_config = block_config.attn_config
-    self.rope_cache = build_rope_cache(
+    self.rope_cache = _build_rope_cache(
         size=config.kv_cache_max,
         dim=int(attn_config.rotary_percentage * attn_config.head_dim),
         base=10_000,
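With the helper demoted to a module-private _build_rope_cache and all of its defaults removed, every call site now has to spell out each argument, and theta_factors is always divided into theta. Below is a self-contained sketch of the new signature and the body lines shown above; the sin and return lines are the symmetric continuation of the cos line, and the call values are illustrative only, not taken from phi3.py.

from typing import Tuple

import torch


def _build_rope_cache(
    size: int,
    dim: int,
    base: int,
    condense_ratio: int,
    dtype: torch.dtype,
    device: torch.device,
    theta_factors: torch.Tensor,
    scale: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
  """Condensed restatement of the function body shown in the diff above."""
  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
  # theta has dim // 2 entries, so theta_factors must broadcast against it.
  theta = theta / theta_factors
  seq_idx = torch.arange(size) / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
  sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
  return cos, sin


# Illustrative call with every argument spelled out (example values only).
cos, sin = _build_rope_cache(
    size=1024,                     # kv_cache_max in the real model config
    dim=64,                        # rotary_percentage * head_dim in the real model
    base=10_000,
    condense_ratio=1,
    dtype=torch.float32,
    device=torch.device("cpu"),
    theta_factors=torch.ones(32),  # the Phi-3.5 ROPE factors in the real model
    scale=1.0,
)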
ai_edge_torch/generative/examples/phi/verify.py
@@ -14,14 +14,17 @@
 # ==============================================================================
 
 """Verifies the reauthored Phi-2 model."""
+import logging
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import kagglehub
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -36,25 +39,23 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
-  verifier.log_msg("Loading the original model from", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
-  verifier.log_msg("Building the reauthored model from", checkpoint)
+  logging.info("Building the reauthored model from: %s", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-03,
   )
 
ai_edge_torch/generative/examples/phi/verify_phi3.py
@@ -15,14 +15,17 @@
 
 """Verifies the reauthored Phi-3.5 model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.phi import phi3
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Instruct: Write an email about the weather Output:",
@@ -37,30 +40,28 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
 
 def main(_):
   checkpoint = "microsoft/Phi-3.5-mini-instruct"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-      hf_generation_config=generation_config,
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = phi3.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
   )
 
 
ai_edge_torch/generative/examples/smollm/verify.py
@@ -15,43 +15,53 @@
 
 """Verifies the reauthored SmolLM-135M model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
    "What is the meaning of life?",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-  )
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = smollm.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
       atol=1e-04,
   )
 
ai_edge_torch/generative/examples/tiny_llama/verify.py
@@ -15,45 +15,55 @@
 
 """Verifies the reauthored TinyLlama-1.1B model."""
 
+import logging
 import pathlib
 
 from absl import app
 from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import transformers_verifier
 from ai_edge_torch.generative.utilities import verifier
 import transformers
 
+
 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
     "Show me the program to add 2 and 3.",
     "The input prompts to generate answers.",
 )
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
 
 
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-  verifier.log_msg("Loading the original model from", checkpoint)
-  wrapper_model = verifier.ModelWrapper(
-      model=transformers.AutoModelForCausalLM.from_pretrained(
-          checkpoint, trust_remote_code=True
-      ),
+  logging.info("Loading the original model from: %s", checkpoint)
+  original_model = transformers.AutoModelForCausalLM.from_pretrained(
+      checkpoint, trust_remote_code=True
   )
+
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
   )
   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
-  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
 
-  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  logging.info("Loading the tokenizer from: %s", checkpoint)
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
   verifier.verify_reauthored_model(
-      original_model=wrapper_model,
-      reauthored_model=reauthored_model,
-      tokenizer=tokenizer,
+      original_model=transformers_verifier.TransformersModelWrapper(
+          original_model
+      ),
+      reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+      tokenizer=verifier.TokenizerWrapper(tokenizer),
       generate_prompts=_PROMPTS.value,
+      max_new_tokens=_MAX_NEW_TOKENS.value,
      atol=1e-04,
   )
 
ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -19,6 +19,7 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/utilities/transformers_verifier.py (new file)
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
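As a usage sketch only (not part of the change set), the wrapper can be exercised on its own: forward() surfaces the raw logits tensor instead of a transformers output object, and generate() threads max_new_tokens through a transformers.GenerationConfig. The checkpoint below is the SmolLM-135M one already used by the verify script above; the rest of the snippet is an assumption based on how the verify scripts construct the wrapper.

import transformers

from ai_edge_torch.generative.utilities import transformers_verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
wrapper = transformers_verifier.TransformersModelWrapper(model)

# forward() returns a (batch, seq_len, vocab_size) logits tensor.
tokens = tokenizer("What is the meaning of life?", return_tensors="pt")["input_ids"]
logits = wrapper.forward(tokens)

# generate() builds a GenerationConfig from max_new_tokens internally and
# returns token ids that include the prompt.
output_ids = wrapper.generate(tokens, max_new_tokens=30)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))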