ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl

Files changed (23)
  1. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  2. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  3. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  4. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  5. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  6. ai_edge_torch/generative/examples/llama/llama.py +203 -0
  7. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  8. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  9. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  10. ai_edge_torch/generative/examples/phi/phi3.py +15 -21
  11. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  12. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  13. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  14. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  15. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  16. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  17. ai_edge_torch/generative/utilities/verifier.py +117 -97
  18. ai_edge_torch/version.py +1 -1
  19. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +23 -16
  21. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/phi/phi3.py CHANGED
@@ -97,15 +97,15 @@ ROPE_SHORT_FACTOR = [
  ]
 
 
- def build_rope_cache(
+ def _build_rope_cache(
      size: int,
      dim: int,
-     base: int = 10000,
-     condense_ratio: int = 1,
-     dtype: torch.dtype = torch.float32,
-     device: torch.device = None,
-     theta_factors: torch.Tensor = None,
-     scale: float = 1.0,
+     base: int,
+     condense_ratio: int,
+     dtype: torch.dtype,
+     device: torch.device,
+     theta_factors: torch.Tensor,
+     scale: float,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
 
@@ -116,26 +116,20 @@ def build_rope_cache(
    Args:
      size (int): The size of the built cache.
      dim (int): Each sequence's dimmension.
-     base (int, optional): Rope base value. Defaults to 10000.
+     base (int, optional): Rope base value.
      condense_ratio (int, optional): The ratio by which sequence indicies are
-       condensed. Defaults to 1.
-     dtype (torch.dtype, optional): Output tensor's data type. Defaults to
-       torch.float32.
-     device (torch.device, optional): Output tensor's data type. Defaults to
-       None in which case "cpu" is used.
+       condensed.
+     dtype (torch.dtype, optional): Output tensor's data type.
+     device (torch.device, optional): Output tensor's data type.
      theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-       scale the theta values. Defaults to None.
-     scale (float, optional): A float used to scale the rope values. Defaults
-       to 1.0.
+       scale the theta values.
+     scale (float, optional): A float used to scale the rope values.
 
    Returns:
      Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
    """
-   if device is None:
-     device = torch.device('cpu')
    theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-   if theta_factors is not None:
-     theta = theta / theta_factors
+   theta = theta / theta_factors
    seq_idx = torch.arange(size) / condense_ratio
    idx_theta = torch.outer(seq_idx, theta)
    cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
@@ -167,7 +161,7 @@ class Phi3_5Mini(nn.Module):
          config.final_norm_config,
      )
      attn_config = block_config.attn_config
-     self.rope_cache = build_rope_cache(
+     self.rope_cache = _build_rope_cache(
          size=config.kv_cache_max,
          dim=int(attn_config.rotary_percentage * attn_config.head_dim),
          base=10_000,
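
Because the refactored `_build_rope_cache` above no longer carries default argument values, every caller has to pass all parameters explicitly, as the `Phi3_5Mini` constructor now does. A minimal standalone sketch of an equivalent call; the sizes and the `theta_factors` value below are placeholders for illustration, not the real Phi-3.5 configuration:

import torch

from ai_edge_torch.generative.examples.phi import phi3

# Placeholder dimensions; the real values come from the model config
# (kv_cache_max, rotary_percentage * head_dim, ROPE_SHORT_FACTOR, ...).
dim = 96
cos, sin = phi3._build_rope_cache(
    size=1024,
    dim=dim,
    base=10_000,
    condense_ratio=1,
    dtype=torch.float32,
    device=torch.device("cpu"),
    theta_factors=torch.ones(dim // 2),  # one scaling factor per theta term
    scale=1.0,
)
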
ai_edge_torch/generative/examples/phi/verify.py CHANGED
@@ -19,6 +19,7 @@ import logging
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.phi import phi2
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import kagglehub
  import transformers
@@ -39,12 +40,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  def main(_):
    checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
    logging.info("Loading the original model from: %s", checkpoint)
-   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-   wrapper_model = verifier.ModelWrapper(
-       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-       hf_generation_config=generation_config,
-   )
+   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
    logging.info("Building the reauthored model from: %s", checkpoint)
    reauthored_model = phi2.build_model(checkpoint)
@@ -53,10 +49,13 @@ def main(_):
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
    verifier.verify_reauthored_model(
-       original_model=wrapper_model,
-       reauthored_model=reauthored_model,
-       tokenizer=tokenizer,
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
        atol=1e-03,
    )
 
ai_edge_torch/generative/examples/phi/verify_phi3.py CHANGED
@@ -21,6 +21,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers
 
@@ -40,12 +41,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
  def main(_):
    checkpoint = "microsoft/Phi-3.5-mini-instruct"
    logging.info("Loading the original model from: %s", checkpoint)
-   generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
-   generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
-   wrapper_model = verifier.ModelWrapper(
-       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-       hf_generation_config=generation_config,
-   )
+   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
 
    # Locate the cached dir.
    cached_config_file = transformers.utils.cached_file(
@@ -59,10 +55,13 @@ def main(_):
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
    verifier.verify_reauthored_model(
-       original_model=wrapper_model,
-       reauthored_model=reauthored_model,
-       tokenizer=tokenizer,
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
    )
 
 
ai_edge_torch/generative/examples/smollm/verify.py CHANGED
@@ -21,6 +21,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.smollm import smollm
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers
 
@@ -30,14 +31,18 @@ _PROMPTS = flags.DEFINE_multi_string(
      "What is the meaning of life?",
      "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
 
 
  def main(_):
    checkpoint = "HuggingFaceTB/SmolLM-135M"
    logging.info("Loading the original model from: %s", checkpoint)
-   wrapper_model = verifier.ModelWrapper(
-       model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
-   )
+   original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
+
    # Locate the cached dir.
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
@@ -50,10 +55,13 @@ def main(_):
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
    verifier.verify_reauthored_model(
-       original_model=wrapper_model,
-       reauthored_model=reauthored_model,
-       tokenizer=tokenizer,
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
        atol=1e-04,
    )
 
ai_edge_torch/generative/examples/tiny_llama/verify.py CHANGED
@@ -21,6 +21,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers
 
@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
      "Show me the program to add 2 and 3.",
      "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+     "max_new_tokens",
+     30,
+     "The maximum size of the generated tokens.",
+ )
 
 
  def main(_):
    checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    logging.info("Loading the original model from: %s", checkpoint)
-   wrapper_model = verifier.ModelWrapper(
-       model=transformers.AutoModelForCausalLM.from_pretrained(
-           checkpoint, trust_remote_code=True
-       ),
+   original_model = transformers.AutoModelForCausalLM.from_pretrained(
+       checkpoint, trust_remote_code=True
    )
+
    # Locate the cached dir.
    cached_config_file = transformers.utils.cached_file(
        checkpoint, transformers.utils.CONFIG_NAME
@@ -52,10 +57,13 @@ def main(_):
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
 
    verifier.verify_reauthored_model(
-       original_model=wrapper_model,
-       reauthored_model=reauthored_model,
-       tokenizer=tokenizer,
+       original_model=transformers_verifier.TransformersModelWrapper(
+           original_model
+       ),
+       reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+       tokenizer=verifier.TokenizerWrapper(tokenizer),
        generate_prompts=_PROMPTS.value,
+       max_new_tokens=_MAX_NEW_TOKENS.value,
        atol=1e-04,
    )
 
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,6 +19,7 @@ import ai_edge_torch
  from ai_edge_torch import config as ai_edge_config
  from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.examples.llama import llama
  from ai_edge_torch.generative.examples.openelm import openelm
  from ai_edge_torch.generative.examples.phi import phi2
  from ai_edge_torch.generative.examples.phi import phi3
@@ -102,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
      pytorch_model = gemma2.Gemma2(config).eval()
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+   @googletest.skipIf(
+       ai_edge_config.Config.use_torch_xla,
+       reason="tests with custom ops are not supported on oss",
+   )
+   def test_llama(self):
+     config = llama.get_fake_model_config()
+     pytorch_model = llama.Llama(config).eval()
+     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
    @googletest.skipIf(
        ai_edge_config.Config.use_torch_xla,
        reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/utilities/transformers_verifier.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utilities for the models predefined in HuggingFace transformers."""
+
+ from typing import cast
+
+ from ai_edge_torch.generative.utilities import verifier
+ import torch
+ import transformers
+
+
+ class TransformersModelWrapper(verifier.ModelWrapper):
+   """A wrapper for the model predefined in HuggingFace transformers.
+
+   Verifier expects forward() to return logits while Transformers models return
+   an object with `logits` field.
+
+   Transformers models get `max_new_tokens` settings for generate() via
+   GenerationConfig.
+   """
+
+   def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+     return self.model.forward(tokens).logits
+
+   def generate(
+       self, inputs: torch.Tensor, max_new_tokens: int
+   ) -> torch.IntTensor:
+     gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+     return self.model.generate(inputs=inputs, generation_config=gen_config)
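
The updated verify scripts pass this wrapper as `original_model`. A minimal usage sketch, assuming any HuggingFace causal-LM checkpoint (SmolLM-135M is used here purely as an example):

import torch
import transformers

from ai_edge_torch.generative.utilities import transformers_verifier

hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM-135M"
)
wrapped = transformers_verifier.TransformersModelWrapper(hf_model)

tokens = torch.tensor([[1, 2, 3, 4]])
logits = wrapped.forward(tokens)  # a plain logits tensor, not a ModelOutput
response_ids = wrapped.generate(tokens, max_new_tokens=8)

Building the GenerationConfig inside generate() is what lets verify_reauthored_model accept a single max_new_tokens argument instead of a pre-configured wrapper.
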
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,111 +16,129 @@
  """Common utility functions to verify the reauthored models."""
 
  import logging
- from typing import List, Optional, Union
+ from typing import List
 
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import torch
- import transformers
 
 
  class ModelWrapper(torch.nn.Module):
-   """A wrapper for the model to be verified, this could be a HuggingFace model
+   """A wrapper for the model to be verified.
 
-   or a regular PyTorch model.
+   It unifies the interface of forward() and generate() of models for the
+   verification to call.
    """
 
-   def __init__(
-       self,
-       model: torch.nn.Module,
-       model_format: str = "huggingface",
-       hf_generation_config: Optional[transformers.GenerationConfig] = None,
-   ):
+   def __init__(self, model: torch.nn.Module):
     """Initializes the wrapper.
 
     Args:
-       model (torch.nn.Module): The original model. This could be a model built
-         from HuggingFace transformers, or a regular PyTorch model.
-       model_format (str): The format of the model. It should be either
-         "huggingface" or "pytorch".
-       hf_generation_config (transformers.GenerationConfig): The HuggingFace
-         generation config. This config will only be used if the underlying model
-         is built from HuggingFace transformers.
+       model (torch.nn.Module): The model which might have different interfaces
+         of forward() and generate(). It could be a model built from HuggingFace
+         transformers, a regular PyTorch model, or a model re-authored with
+         ai_edge_torch Generative API.
     """
     super().__init__()
     self.model = model
-     self.model_format = model_format
-     self.hf_generation_config = hf_generation_config
+
+   def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+     """Gets output logits by forwarding the input tokens.
+
+     Args:
+       tokens (torch.Tensor): The input tokens to forward. Its dimension is
+         expected to be (batch_size=1, kv_cache_max_len).
+
+     Returns:
+       The output logits.
+     """
+     raise NotImplementedError("forward() is not implemented.")
 
    def generate(
-       self, inputs: torch.Tensor
-   ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
-     if self.model_format == "huggingface":
-       return self.model.generate(
-           inputs=inputs, generation_config=self.hf_generation_config
-       )
-     else:
-       raise NotImplementedError(
-           "generate() is not implemented for model format: %s"
-           % self.model_format
-       )
+       self, prompts: torch.Tensor, max_new_tokens: int
+   ) -> torch.IntTensor:
+     """Returns the response token IDs to the given prompts tensor.
 
-   def forward(
-       self,
-       inputs: torch.Tensor,
-   ):
-     return self.model.forward(inputs)
+     The maximum number of tokens to generate might be set by subclasses.
 
+     Args:
+       prompts (torch.Tensor): The input token IDs to generate with. Its shape is
+         expected to be (batch_size=1, input_ids_len).
+       max_new_tokens (int): The maximum number of response token IDs to
+         generate.
+
+     Returns:
+       The tensor of response token IDs with shape of (batch_size=1,
+       response_ids_len).
+     """
+     raise NotImplementedError("generate() is not implemented.")
 
- def forward(
-     model: torch.nn.Module,
-     tokens: torch.Tensor,
-     kv_cache: kv_utils.KVCache,
- ) -> tuple[torch.Tensor, kv_utils.KVCache]:
-   """Forwards the model reauthored with ai_edge_torch Generative API.
 
-   Args:
-     model (torch.nn.Module): The model to forward. It should be a model built
-       with ai_edge_torch Generative API.
-     tokens (torch.Tensor): The input tokens to forward.
-     kv_cache (KVCache): The KV cache to forward.
+ class ReauthoredModelWrapper(ModelWrapper):
+   """A wrapper for the model reauthored with ai_edge_torch Generative API."""
 
-   Returns:
-     The output logits and the updated KV cache.
-   """
-   input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
-   output = model.forward(tokens, input_pos, kv_cache)
-   return output["logits"], output["kv_cache"]
+   def _init_kv_cache(self):
+     """Returns an initialized KV cache."""
+     return kv_utils.KVCache.from_model_config(self.model.config)
 
+   def _forward_with_kv_cache(
+       self,
+       tokens: torch.Tensor,
+       kv_cache: kv_utils.KVCache,
+   ) -> tuple[torch.Tensor, kv_utils.KVCache]:
+     """Forwards the model and updates an external KV cache.
 
- def generate(
-     model: torch.nn.Module, prompts: torch.Tensor, response_len: int
- ) -> torch.Tensor:
-   """Generates the response to the prompts.
+     Args:
+       tokens (torch.Tensor): The input tokens to forward.
+       kv_cache (KVCache): The KV cache to forward.
 
-   It appends tokens output by the model to the prompts and feeds them back to
-   the model up to decode_len.
+     Returns:
+       The output logits and the updated KV cache.
+     """
+     input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+     output = self.model.forward(tokens, input_pos, kv_cache)
+     return output["logits"], output["kv_cache"]
 
-   Args:
-     model (torch.nn.Module): The model to generate. It should be a model built
-       with ai_edge_torch Generative API.
-     prompts (torch.Tensor): The prompts to generate.
-     response_len (int): The number of tokens to generate.
+   def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+     logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
+     return logits
 
-   Returns:
-     The generated tokens.
-   """
-   input_ids = prompts[0].int().tolist()
-   kv_cache = kv_utils.KVCache.from_model_config(model.config)
-   for _ in range(response_len - len(input_ids)):
-     logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
-     generated_token = logits[0][-1].argmax().item()
-     input_ids.append(generated_token)
-   return torch.tensor([input_ids])
+   def generate(
+       self, prompts: torch.Tensor, max_new_tokens: int
+   ) -> torch.IntTensor:
+     input_ids = prompts[0].int().tolist()
+     kv_cache = self._init_kv_cache()
+     for _ in range(max_new_tokens):
+       tokens = torch.tensor([input_ids])
+       logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
+       generated_token = logits[0][-1].argmax().item()
+       input_ids.append(generated_token)
+     return torch.tensor([input_ids])
+
+
+ class TokenizerWrapper(torch.nn.Module):
+   """A wrapper for the tokenizer used for verification."""
+
+   def __init__(self, tokenizer: torch.nn.Module):
+     """Initializes the wrapper.
+
+     Args:
+       tokenizer (torch.nn.Module): The tokenizer to wrap.
+     """
+     super().__init__()
+     self.tokenizer = tokenizer
+
+   def encode(self, prompts: str) -> torch.Tensor:
+     """Encodes the prompts to token IDs."""
+     return self.tokenizer.encode(prompts, return_tensors="pt")
+
+   def decode(self, token_ids: torch.Tensor) -> str:
+     """Decodes the token IDs to a string."""
+     return self.tokenizer.decode(token_ids)
 
 
  def verify_with_input_ids(
      original_model: ModelWrapper,
-     reauthored_model: torch.nn.Module,
+     reauthored_model: ReauthoredModelWrapper,
      input_ids: List[int],
      kv_cache_max_len: int = 1024,
      rtol: float = 1e-05,
@@ -132,8 +150,8 @@ def verify_with_input_ids(
 
    Args:
      original_model (ModelWrapper): The original model.
-     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-       Generative API.
+     reauthored_model (ReauthoredModelWrapper): The model reauthored with
+       ai_edge_torch Generative API.
      input_ids (List[int]): The input token IDs to forward with.
      kv_cache_max_len (int): The maximum sequence length of the KV cache.
      rtol (float): The relative tolerance for the comparison.
@@ -147,13 +165,12 @@ def verify_with_input_ids(
 
    logging.info("Forwarding the original model...")
    outputs_original = original_model.forward(tokens)
-   logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
+   logits_original = outputs_original[0, len(input_ids) - 1, :]
    logging.info("logits_original: %s", logits_original)
 
    logging.info("Forwarding the reauthored model...")
-   kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
-   outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-   logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
+   outputs_reauthored = reauthored_model.forward(tokens)
+   logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
    logging.info("logits_reauthored: %s", logits_reauthored)
 
    return torch.allclose(
@@ -163,9 +180,10 @@ def verify_with_input_ids(
 
  def verify_model_with_prompts(
      original_model: ModelWrapper,
-     reauthored_model: torch.nn.Module,
-     tokenizer: torch.nn.Module,
+     reauthored_model: ReauthoredModelWrapper,
+     tokenizer: TokenizerWrapper,
      prompts: str,
+     max_new_tokens: int,
  ) -> bool:
    """Verifies if the model reauthored generates the same answer of the oringal.
 
@@ -174,24 +192,24 @@ def verify_model_with_prompts(
 
    Args:
      original_model (ModelWrapper): The original model.
-     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-       Generative API.
-     tokenizer (torch.nn.Module): The tokenizer.
+     reauthored_model (ReauthoredModelWrapper): The model reauthored with
+       ai_edge_torch Generative API.
+     tokenizer (TokenizerWrapper): The tokenizer.
      prompts (str): The input prompts to generate answers.
+     max_new_tokens (int): The maximum number of new tokens to generate.
 
    Returns:
      True if the model reauthored generates the same answer of the original.
    """
-   prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+   prompt_tokens = tokenizer.encode(prompts)
 
    logging.info("Generating answer with the original model...")
-   outputs_original = original_model.generate(prompt_tokens)
+   outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
    response_original = tokenizer.decode(outputs_original[0])
    logging.info("outputs_from_original_model: [[%s]]", response_original)
 
    logging.info("Generating answer with the reauthored model...")
-   generate_len = len(outputs_original[0])
-   outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+   outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
    response_reauthored = tokenizer.decode(outputs_reauthored[0])
    logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
@@ -200,9 +218,10 @@ def verify_model_with_prompts(
 
  def verify_reauthored_model(
      original_model: ModelWrapper,
-     reauthored_model: torch.nn.Module,
-     tokenizer: torch.nn.Module,
+     reauthored_model: ReauthoredModelWrapper,
+     tokenizer: TokenizerWrapper,
      generate_prompts: List[str],
+     max_new_tokens: int = 30,
      forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
      rtol: float = 1e-05,
      atol: float = 1e-05,
@@ -219,10 +238,11 @@ def verify_reauthored_model(
 
    Args:
      original_model (ModelWrapper): The original model.
-     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-       Generative API.
-     tokenizer (torch.nn.Module): The tokenizer.
+     reauthored_model (ReauthoredModelWrapper): The model reauthored with
+       ai_edge_torch Generative API.
+     tokenizer (TokenizerWrapper): The tokenizer.
      generate_prompts (List[str]): List of the input prompts to generate answers.
+     max_new_tokens (int): The maximum number of new tokens to generate.
      forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
        forward with.
      rtol (float): The relative tolerance for the comparison.
@@ -235,13 +255,13 @@ def verify_reauthored_model(
    ):
      logging.info("PASS")
    else:
-     logging.info("FAILED")
+     logging.error("FAILED")
 
    for prompts in generate_prompts:
      logging.info("Verifying the reauthored model with prompts:%s", prompts)
      if verify_model_with_prompts(
-         original_model, reauthored_model, tokenizer, prompts
+         original_model, reauthored_model, tokenizer, prompts, max_new_tokens
      ):
        logging.info("PASS")
      else:
-       logging.info("FAILED")
+       logging.error("FAILED")
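
With forward() and generate() now abstract on ModelWrapper, an original model that is neither a HuggingFace model nor a reauthored one can still plug into verify_reauthored_model by subclassing the wrapper. A hedged sketch; the PlainPyTorchWrapper name and its greedy-decode loop are illustrative, not part of the library:

import torch

from ai_edge_torch.generative.utilities import verifier


class PlainPyTorchWrapper(verifier.ModelWrapper):
  """Illustrative wrapper for a plain PyTorch LM whose forward() returns logits."""

  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
    return self.model(tokens)

  def generate(
      self, prompts: torch.Tensor, max_new_tokens: int
  ) -> torch.IntTensor:
    # Simple greedy decode, mirroring ReauthoredModelWrapper.generate().
    input_ids = prompts[0].int().tolist()
    for _ in range(max_new_tokens):
      logits = self.forward(torch.tensor([input_ids]))
      input_ids.append(logits[0][-1].argmax().item())
    return torch.tensor([input_ids])

This mirrors how TransformersModelWrapper and ReauthoredModelWrapper adapt their own backends, so verify_reauthored_model itself stays backend-agnostic.
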
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20240926"
+ __version__ = "0.3.0.dev20240927"
{ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240926
+ Version: 0.3.0.dev20240927
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI