ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240929__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/qwen/__init__.py +14 -0
  18. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +81 -0
  19. ai_edge_torch/generative/examples/qwen/qwen.py +141 -0
  20. ai_edge_torch/generative/examples/qwen/verify.py +88 -0
  21. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  22. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  23. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  24. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  25. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  26. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  27. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  28. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  29. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  30. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  31. ai_edge_torch/generative/layers/model_config.py +2 -0
  32. ai_edge_torch/generative/test/test_model_conversion_large.py +20 -0
  33. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  34. ai_edge_torch/generative/utilities/verifier.py +117 -97
  35. ai_edge_torch/version.py +1 -1
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/METADATA +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/RECORD +40 -29
  38. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/LICENSE +0 -0
  39. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/WHEEL +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -44,13 +44,10 @@ class ToySingleLayerModel(torch.nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.max_seq_len,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device('cpu'),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ size=config.max_seq_len,
  )
  self.config = config

@@ -93,13 +90,10 @@ class ToySingleLayerModelWeightSharing(torch.nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.max_seq_len,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device('cpu'),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ size=config.max_seq_len,
  )
  self.config = config

@@ -124,6 +118,7 @@ def get_model_config() -> cfg.ModelConfig:
  num_heads=32,
  head_dim=4,
  num_query_groups=4,
+ rotary_base=10000,
  rotary_percentage=1.0,
  enable_kv_cache=False,
  )
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py CHANGED
@@ -51,13 +51,10 @@ class ToyModelWithKVCache(torch.nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.max_seq_len,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device('cpu'),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
- size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ size=config.max_seq_len,
  )
  self.config = config

@@ -91,6 +88,7 @@ def get_model_config() -> cfg.ModelConfig:
  num_heads=32,
  head_dim=4,
  num_query_groups=4,
+ rotary_base=10000,
  rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -67,15 +67,10 @@ class TinyLlama(nn.Module):
  self.rope_cache = attn_utils.build_rope_cache(
  size=config.kv_cache_max,
  dim=int(attn_config.rotary_percentage * attn_config.head_dim),
- base=10_000,
- condense_ratio=1,
- dtype=torch.float32,
- device=torch.device("cpu"),
+ base=attn_config.rotary_base,
  )
  self.mask_cache = attn_utils.build_causal_mask_cache(
  size=config.kv_cache_max,
- dtype=torch.float32,
- device=torch.device("cpu"),
  )
  self.config = config

@@ -132,6 +127,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  num_heads=32,
  head_dim=64,
  num_query_groups=4,
+ rotary_base=10000,
  rotary_percentage=1.0,
  )
  ff_config = cfg.FeedForwardConfig(
ai_edge_torch/generative/examples/tiny_llama/verify.py CHANGED
@@ -21,6 +21,7 @@ import pathlib
  from absl import app
  from absl import flags
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.utilities import transformers_verifier
  from ai_edge_torch.generative.utilities import verifier
  import transformers

@@ -30,16 +31,20 @@ _PROMPTS = flags.DEFINE_multi_string(
  "Show me the program to add 2 and 3.",
  "The input prompts to generate answers.",
  )
+ _MAX_NEW_TOKENS = flags.DEFINE_integer(
+ "max_new_tokens",
+ 30,
+ "The maximum size of the generated tokens.",
+ )


  def main(_):
  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
  logging.info("Loading the original model from: %s", checkpoint)
- wrapper_model = verifier.ModelWrapper(
- model=transformers.AutoModelForCausalLM.from_pretrained(
- checkpoint, trust_remote_code=True
- ),
+ original_model = transformers.AutoModelForCausalLM.from_pretrained(
+ checkpoint, trust_remote_code=True
  )
+
  # Locate the cached dir.
  cached_config_file = transformers.utils.cached_file(
  checkpoint, transformers.utils.CONFIG_NAME
@@ -52,10 +57,13 @@ def main(_):
  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

  verifier.verify_reauthored_model(
- original_model=wrapper_model,
- reauthored_model=reauthored_model,
- tokenizer=tokenizer,
+ original_model=transformers_verifier.TransformersModelWrapper(
+ original_model
+ ),
+ reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
+ tokenizer=verifier.TokenizerWrapper(tokenizer),
  generate_prompts=_PROMPTS.value,
+ max_new_tokens=_MAX_NEW_TOKENS.value,
  atol=1e-04,
  )

ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -83,6 +83,8 @@ class AttentionConfig:
  # Used to determine number of groups in grouped query attention (GQA)
  # https://arxiv.org/pdf/2305.13245.pdf
  num_query_groups: Optional[int]
+ # Base of rotary positional embedding.
+ rotary_base: int = 10_000
  # Percentage of Rotary Positional Embedding added Q and K projections.
  rotary_percentage: Optional[float] = None
  # Whether to transpose the query groups of qkv bundled tensor before
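For orientation, here is a minimal sketch (not part of the diff) of how the new `rotary_base` field flows from the attention config into the RoPE cache, mirroring the call pattern in the toy_model.py and tiny_llama.py hunks above. The import paths for `attention_utils` and `model_config` are my assumption of the package layout, and the concrete values are illustrative only:

```python
# Illustrative sketch; follows the config -> build_rope_cache pattern in this diff.
# Assumed import paths for the ai_edge_torch generative layers.
from ai_edge_torch.generative.layers import attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg

attn_config = cfg.AttentionConfig(
    num_heads=32,
    head_dim=64,
    num_query_groups=4,
    rotary_base=10000,  # new field; defaults to 10_000 in AttentionConfig
    rotary_percentage=1.0,
)

# The RoPE base now comes from the config instead of a hard-coded 10_000;
# dtype and device fall back to the function's defaults.
rope_cache = attn_utils.build_rope_cache(
    size=1024,
    dim=int(attn_config.rotary_percentage * attn_config.head_dim),
    base=attn_config.rotary_base,
)
```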
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,9 +19,11 @@ import ai_edge_torch
  from ai_edge_torch import config as ai_edge_config
  from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
+ from ai_edge_torch.generative.examples.llama import llama
  from ai_edge_torch.generative.examples.openelm import openelm
  from ai_edge_torch.generative.examples.phi import phi2
  from ai_edge_torch.generative.examples.phi import phi3
+ from ai_edge_torch.generative.examples.qwen import qwen
  from ai_edge_torch.generative.examples.smollm import smollm
  from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
  from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
@@ -102,6 +104,15 @@ class TestModelConversion(googletest.TestCase):
  pytorch_model = gemma2.Gemma2(config).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

+ @googletest.skipIf(
+ ai_edge_config.Config.use_torch_xla,
+ reason="tests with custom ops are not supported on oss",
+ )
+ def test_llama(self):
+ config = llama.get_fake_model_config()
+ pytorch_model = llama.Llama(config).eval()
+ self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
  @googletest.skipIf(
  ai_edge_config.Config.use_torch_xla,
  reason="tests with custom ops are not supported on oss",
@@ -142,6 +153,15 @@ class TestModelConversion(googletest.TestCase):
  pytorch_model = openelm.OpenELM(config).eval()
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

+ @googletest.skipIf(
+ ai_edge_config.Config.use_torch_xla,
+ reason="tests with custom ops are not supported on oss",
+ )
+ def test_qwen(self):
+ config = qwen.get_fake_model_config()
+ pytorch_model = qwen.Qwen(config).eval()
+ self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
  @googletest.skipIf(
  ai_edge_config.Config.use_torch_xla,
  reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/utilities/transformers_verifier.py ADDED
@@ -0,0 +1,42 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Utilities for the models predefined in HuggingFace transformers."""
+
+ from typing import cast
+
+ from ai_edge_torch.generative.utilities import verifier
+ import torch
+ import transformers
+
+
+ class TransformersModelWrapper(verifier.ModelWrapper):
+ """A wrapper for the model predefined in HuggingFace transformers.
+
+ Verifier expects forward() to return logits while Transformers models return
+ an object with `logits` field.
+
+ Transformers models get `max_new_tokens` settings for generate() via
+ GenerationConfig.
+ """
+
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+ return self.model.forward(tokens).logits
+
+ def generate(
+ self, inputs: torch.Tensor, max_new_tokens: int
+ ) -> torch.IntTensor:
+ gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+ return self.model.generate(inputs=inputs, generation_config=gen_config)
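A hedged usage sketch (not part of the diff) of the new wrapper above: it adapts a HuggingFace causal LM so the verifier receives a plain logits tensor from forward() and token IDs from generate(). The checkpoint name is just the TinyLlama example used elsewhere in this release; the token values are arbitrary:

```python
# Sketch based on TransformersModelWrapper as defined above and the
# verify.py changes in this diff; checkpoint and tokens are illustrative.
import torch
import transformers

from ai_edge_torch.generative.utilities import transformers_verifier

hf_model = transformers.AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", trust_remote_code=True
)
wrapped = transformers_verifier.TransformersModelWrapper(hf_model)

tokens = torch.tensor([[1, 2, 3, 4]])
logits = wrapped.forward(tokens)  # raw logits tensor, not a ModelOutput object
ids = wrapped.generate(tokens, max_new_tokens=30)  # GenerationConfig built internally
```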
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -16,111 +16,129 @@
  """Common utility functions to verify the reauthored models."""

  import logging
- from typing import List, Optional, Union
+ from typing import List

  from ai_edge_torch.generative.layers import kv_cache as kv_utils
  import torch
- import transformers


  class ModelWrapper(torch.nn.Module):
- """A wrapper for the model to be verified, this could be a HuggingFace model
+ """A wrapper for the model to be verified.

- or a regular PyTorch model.
+ It unifies the interface of forward() and generate() of models for the
+ verification to call.
  """

- def __init__(
- self,
- model: torch.nn.Module,
- model_format: str = "huggingface",
- hf_generation_config: Optional[transformers.GenerationConfig] = None,
- ):
+ def __init__(self, model: torch.nn.Module):
  """Initializes the wrapper.

  Args:
- model (torch.nn.Module): The original model. This could be a model built
- from HuggingFace transformers, or a regular PyTorch model.
- model_format (str): The format of the model. It should be either
- "huggingface" or "pytorch".
- hf_generation_config (transformers.GenerationConfig): The HuggingFace
- generation config. This config will only be used if the underlying model
- is built from HuggingFace transformers.
+ model (torch.nn.Module): The model which might have different interfaces
+ of forward() and generate(). It could be a model built from HuggingFace
+ transformers, a regular PyTorch model, or a model re-authored with
+ ai_edge_torch Generative API.
  """
  super().__init__()
  self.model = model
- self.model_format = model_format
- self.hf_generation_config = hf_generation_config
+
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+ """Gets output logits by forwarding the input tokens.
+
+ Args:
+ tokens (torch.Tensor): The input tokens to forward. Its dimension is
+ expected to be (batch_size=1, kv_cache_max_len).
+
+ Returns:
+ The output logits.
+ """
+ raise NotImplementedError("forward() is not implemented.")

  def generate(
- self, inputs: torch.Tensor
- ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
- if self.model_format == "huggingface":
- return self.model.generate(
- inputs=inputs, generation_config=self.hf_generation_config
- )
- else:
- raise NotImplementedError(
- "generate() is not implemented for model format: %s"
- % self.model_format
- )
+ self, prompts: torch.Tensor, max_new_tokens: int
+ ) -> torch.IntTensor:
+ """Returns the response token IDs to the given prompts tensor.

- def forward(
- self,
- inputs: torch.Tensor,
- ):
- return self.model.forward(inputs)
+ The maximum number of tokens to generate might be set by subclasses.

+ Args:
+ prompts (torch.Tensor): The input token IDs to generate with. Its shape is
+ expected to be (batch_size=1, input_ids_len).
+ max_new_tokens (int): The maximum number of response token IDs to
+ generate.
+
+ Returns:
+ The tensor of response token IDs with shape of (batch_size=1,
+ response_ids_len).
+ """
+ raise NotImplementedError("generate() is not implemented.")

- def forward(
- model: torch.nn.Module,
- tokens: torch.Tensor,
- kv_cache: kv_utils.KVCache,
- ) -> tuple[torch.Tensor, kv_utils.KVCache]:
- """Forwards the model reauthored with ai_edge_torch Generative API.

- Args:
- model (torch.nn.Module): The model to forward. It should be a model built
- with ai_edge_torch Generative API.
- tokens (torch.Tensor): The input tokens to forward.
- kv_cache (KVCache): The KV cache to forward.
+ class ReauthoredModelWrapper(ModelWrapper):
+ """A wrapper for the model reauthored with ai_edge_torch Generative API."""

- Returns:
- The output logits and the updated KV cache.
- """
- input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
- output = model.forward(tokens, input_pos, kv_cache)
- return output["logits"], output["kv_cache"]
+ def _init_kv_cache(self):
+ """Returns an initialized KV cache."""
+ return kv_utils.KVCache.from_model_config(self.model.config)

+ def _forward_with_kv_cache(
+ self,
+ tokens: torch.Tensor,
+ kv_cache: kv_utils.KVCache,
+ ) -> tuple[torch.Tensor, kv_utils.KVCache]:
+ """Forwards the model and updates an external KV cache.

- def generate(
- model: torch.nn.Module, prompts: torch.Tensor, response_len: int
- ) -> torch.Tensor:
- """Generates the response to the prompts.
+ Args:
+ tokens (torch.Tensor): The input tokens to forward.
+ kv_cache (KVCache): The KV cache to forward.

- It appends tokens output by the model to the prompts and feeds them back to
- the model up to decode_len.
+ Returns:
+ The output logits and the updated KV cache.
+ """
+ input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+ output = self.model.forward(tokens, input_pos, kv_cache)
+ return output["logits"], output["kv_cache"]

- Args:
- model (torch.nn.Module): The model to generate. It should be a model built
- with ai_edge_torch Generative API.
- prompts (torch.Tensor): The prompts to generate.
- response_len (int): The number of tokens to generate.
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+ logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
+ return logits

- Returns:
- The generated tokens.
- """
- input_ids = prompts[0].int().tolist()
- kv_cache = kv_utils.KVCache.from_model_config(model.config)
- for _ in range(response_len - len(input_ids)):
- logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
- generated_token = logits[0][-1].argmax().item()
- input_ids.append(generated_token)
- return torch.tensor([input_ids])
+ def generate(
+ self, prompts: torch.Tensor, max_new_tokens: int
+ ) -> torch.IntTensor:
+ input_ids = prompts[0].int().tolist()
+ kv_cache = self._init_kv_cache()
+ for _ in range(max_new_tokens):
+ tokens = torch.tensor([input_ids])
+ logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
+ generated_token = logits[0][-1].argmax().item()
+ input_ids.append(generated_token)
+ return torch.tensor([input_ids])
+
+
+ class TokenizerWrapper(torch.nn.Module):
+ """A wrapper for the tokenizer used for verification."""
+
+ def __init__(self, tokenizer: torch.nn.Module):
+ """Initializes the wrapper.
+
+ Args:
+ tokenizer (torch.nn.Module): The tokenizer to wrap.
+ """
+ super().__init__()
+ self.tokenizer = tokenizer
+
+ def encode(self, prompts: str) -> torch.Tensor:
+ """Encodes the prompts to token IDs."""
+ return self.tokenizer.encode(prompts, return_tensors="pt")
+
+ def decode(self, token_ids: torch.Tensor) -> str:
+ """Decodes the token IDs to a string."""
+ return self.tokenizer.decode(token_ids)


  def verify_with_input_ids(
  original_model: ModelWrapper,
- reauthored_model: torch.nn.Module,
+ reauthored_model: ReauthoredModelWrapper,
  input_ids: List[int],
  kv_cache_max_len: int = 1024,
  rtol: float = 1e-05,
@@ -132,8 +150,8 @@ def verify_with_input_ids(

  Args:
  original_model (ModelWrapper): The original model.
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
- Generative API.
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
+ ai_edge_torch Generative API.
  input_ids (List[int]): The input token IDs to forward with.
  kv_cache_max_len (int): The maximum sequence length of the KV cache.
  rtol (float): The relative tolerance for the comparison.
@@ -147,13 +165,12 @@ def verify_with_input_ids(

  logging.info("Forwarding the original model...")
  outputs_original = original_model.forward(tokens)
- logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
+ logits_original = outputs_original[0, len(input_ids) - 1, :]
  logging.info("logits_original: %s", logits_original)

  logging.info("Forwarding the reauthored model...")
- kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
- outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
- logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
+ outputs_reauthored = reauthored_model.forward(tokens)
+ logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
  logging.info("logits_reauthored: %s", logits_reauthored)

  return torch.allclose(
@@ -163,9 +180,10 @@ def verify_with_input_ids(

  def verify_model_with_prompts(
  original_model: ModelWrapper,
- reauthored_model: torch.nn.Module,
- tokenizer: torch.nn.Module,
+ reauthored_model: ReauthoredModelWrapper,
+ tokenizer: TokenizerWrapper,
  prompts: str,
+ max_new_tokens: int,
  ) -> bool:
  """Verifies if the model reauthored generates the same answer of the oringal.

@@ -174,24 +192,24 @@ def verify_model_with_prompts(

  Args:
  original_model (ModelWrapper): The original model.
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
- Generative API.
- tokenizer (torch.nn.Module): The tokenizer.
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
+ ai_edge_torch Generative API.
+ tokenizer (TokenizerWrapper): The tokenizer.
  prompts (str): The input prompts to generate answers.
+ max_new_tokens (int): The maximum number of new tokens to generate.

  Returns:
  True if the model reauthored generates the same answer of the original.
  """
- prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+ prompt_tokens = tokenizer.encode(prompts)

  logging.info("Generating answer with the original model...")
- outputs_original = original_model.generate(prompt_tokens)
+ outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
  response_original = tokenizer.decode(outputs_original[0])
  logging.info("outputs_from_original_model: [[%s]]", response_original)

  logging.info("Generating answer with the reauthored model...")
- generate_len = len(outputs_original[0])
- outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+ outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
  response_reauthored = tokenizer.decode(outputs_reauthored[0])
  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)

@@ -200,9 +218,10 @@ def verify_model_with_prompts(

  def verify_reauthored_model(
  original_model: ModelWrapper,
- reauthored_model: torch.nn.Module,
- tokenizer: torch.nn.Module,
+ reauthored_model: ReauthoredModelWrapper,
+ tokenizer: TokenizerWrapper,
  generate_prompts: List[str],
+ max_new_tokens: int = 30,
  forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
  rtol: float = 1e-05,
  atol: float = 1e-05,
@@ -219,10 +238,11 @@ def verify_reauthored_model(

  Args:
  original_model (ModelWrapper): The original model.
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
- Generative API.
- tokenizer (torch.nn.Module): The tokenizer.
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
+ ai_edge_torch Generative API.
+ tokenizer (TokenizerWrapper): The tokenizer.
  generate_prompts (List[str]): List of the input prompts to generate answers.
+ max_new_tokens (int): The maximum number of new tokens to generate.
  forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
  forward with.
  rtol (float): The relative tolerance for the comparison.
@@ -235,13 +255,13 @@ def verify_reauthored_model(
  ):
  logging.info("PASS")
  else:
- logging.info("FAILED")
+ logging.error("FAILED")

  for prompts in generate_prompts:
  logging.info("Verifying the reauthored model with prompts:%s", prompts)
  if verify_model_with_prompts(
- original_model, reauthored_model, tokenizer, prompts
+ original_model, reauthored_model, tokenizer, prompts, max_new_tokens
  ):
  logging.info("PASS")
  else:
- logging.info("FAILED")
+ logging.error("FAILED")
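Putting the wrapper classes together, a verification script now wires everything roughly as below. This is a sketch restating the call pattern from the tiny_llama verify.py hunk earlier in this diff; the builder name `tiny_llama.build_model` and the local checkpoint path are assumptions, and the prompt and tolerance are illustrative:

```python
# Sketch of the new verification flow; mirrors the updated tiny_llama/verify.py.
import transformers

from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier

checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
original_model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

# Assumption: the example module exposes a build_model() helper that loads the
# re-authored model from a locally cached checkpoint directory.
reauthored_model = tiny_llama.build_model("/path/to/cached/tiny_llama")

verifier.verify_reauthored_model(
    original_model=transformers_verifier.TransformersModelWrapper(original_model),
    reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
    tokenizer=verifier.TokenizerWrapper(tokenizer),
    generate_prompts=["Show me the program to add 2 and 3."],
    max_new_tokens=30,
    atol=1e-04,
)
```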
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20240926"
+ __version__ = "0.3.0.dev20240929"
{ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240929.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240926
+ Version: 0.3.0.dev20240929
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI