ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

Files changed (41)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
  3. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  6. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  7. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  10. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  11. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  12. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  13. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  16. ai_edge_torch/generative/examples/phi/phi3.py +279 -0
  17. ai_edge_torch/generative/examples/phi/verify.py +13 -13
  18. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  19. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  20. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  21. ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
  22. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
  23. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
  24. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  25. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  29. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  30. ai_edge_torch/generative/layers/model_config.py +2 -0
  31. ai_edge_torch/generative/layers/normalization.py +2 -2
  32. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  33. ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
  34. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  35. ai_edge_torch/generative/utilities/verifier.py +130 -114
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
  39. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -83,6 +83,8 @@ class AttentionConfig:
   # Used to determine number of groups in grouped query attention (GQA)
   # https://arxiv.org/pdf/2305.13245.pdf
   num_query_groups: Optional[int]
+  # Base of rotary positional embedding.
+  rotary_base: int = 10_000
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
   # Whether to transpose the query groups of qkv bundled tensor before
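Note: `rotary_base` is the base of the geometric frequency progression used by rotary positional embeddings (RoPE). A minimal sketch of how such a base is typically turned into per-dimension inverse frequencies (the helper name and shapes are illustrative, not this library's API):

```python
import torch

def rope_inverse_frequencies(head_dim: int, rotary_base: float = 10_000.0) -> torch.Tensor:
  """Illustrative only: RoPE inverse frequencies, theta_i = base^(-2i/d)."""
  exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
  return 1.0 / (rotary_base ** exponents)

# A larger base stretches the wavelengths, which is how long-context variants
# keep distant positions distinguishable.
inv_freq = rope_inverse_frequencies(head_dim=64, rotary_base=10_000.0)
```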
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -189,7 +189,7 @@ def group_norm_with_hlfb(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
-          "eps": eps,
+          "epsilon": eps,
           "reduction_axes": 3,
           "channel_axis": 3,
       },
@@ -226,7 +226,7 @@ def layer_norm_with_hlfb(
   """
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
-      attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
+      attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
   )
   x, w, b = builder.mark_inputs(x, w, b)
   if use_input_shape:
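Note: both hunks only rename the composite attribute key from `eps` to `epsilon`, so the attribute recorded on the high-level function boundary matches what the ODML group-norm lowering expects. A rough sketch of the surrounding builder pattern, assuming the `StableHLOCompositeBuilder` import path used elsewhere in this repo and a hypothetical composite name:

```python
import torch
from ai_edge_torch.hlfb import StableHLOCompositeBuilder  # import path assumed from repo usage

def rms_norm_with_hlfb(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
  # Illustrative composite (not part of this diff): the attr dict travels with
  # the marked boundary, so downstream passes see "epsilon", not "eps".
  builder = StableHLOCompositeBuilder(
      name="odml.example_rms_norm",  # hypothetical composite name
      attr={"epsilon": eps},
  )
  x, w = builder.mark_inputs(x, w)
  y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w
  return builder.mark_outputs(y)
```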
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
       time_emb: Optional[torch.Tensor] = None,
       context_tensor: Optional[torch.Tensor] = None,
       output_hidden_states: bool = False,
-  ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
     """Forward function of the DownEncoderBlock2D.
 
     Args:
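Note: the return-annotation change replaces the `X | Y` union syntax (PEP 604, runtime-valid only on Python 3.10+) with `typing.Union`, which also works on older interpreters. A quick illustration of the equivalence:

```python
from typing import List, Tuple, Union
import torch

# Python 3.10+ only when evaluated at runtime:
#   def forward(...) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]: ...
# Portable spelling adopted by the diff, valid on 3.9 and earlier too:
ForwardOutput = Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]
```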
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,9 +19,14 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -98,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -109,6 +123,17 @@ class TestModelConversion(googletest.TestCase):
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_phi3(self):
+    config = phi3.get_fake_model_config()
+    pytorch_model = phi3.Phi3_5Mini(config).eval()
+    self._test_model(
+        config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -127,6 +152,110 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_clip(self):
+    config = sd_clip.get_fake_model_config()
+    prompt_tokens = torch.from_numpy(
+        np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
+    )
+
+    pytorch_model = sd_clip.CLIP(config).eval()
+    torch_output = pytorch_model(prompt_tokens)
+
+    edge_model = ai_edge_torch.signature(
+        "encode", pytorch_model, (prompt_tokens,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        prompt_tokens.numpy(),
+        signature_name="encode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_diffusion(self):
+    config = sd_diffusion.get_fake_model_config(2)
+    latents = torch.from_numpy(
+        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+    )
+    context = torch.from_numpy(
+        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+    )
+    time_embedding = torch.from_numpy(
+        np.random.normal(size=(2, 2)).astype(np.float32)
+    )
+
+    pytorch_model = sd_diffusion.Diffusion(config).eval()
+    torch_output = pytorch_model(latents, context, time_embedding)
+
+    edge_model = ai_edge_torch.signature(
+        "diffusion", pytorch_model, (latents, context, time_embedding)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        context.numpy(),
+        time_embedding.numpy(),
+        signature_name="diffusion",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_decoder(self):
+    config = sd_decoder.get_fake_model_config()
+    latents = torch.from_numpy(
+        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+    )
+
+    pytorch_model = sd_decoder.Decoder(config).eval()
+    torch_output = pytorch_model(latents)
+
+    edge_model = ai_edge_torch.signature(
+        "decode", pytorch_model, (latents,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        signature_name="decode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
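Note: the new Stable Diffusion tests all follow the same flow: convert the PyTorch module under a named signature, run the TFLite interpreter through that signature, and compare against the eager output. A condensed sketch of that flow outside the test harness (the model and tolerances below are placeholders, not part of this change):

```python
import numpy as np
import torch
import ai_edge_torch

class TinyModel(torch.nn.Module):  # placeholder module for illustration
  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x) * 2.0

model = TinyModel().eval()
sample_args = (torch.randn(1, 8),)

# Convert under an explicit signature name, then call the TFLite model by that name.
edge_model = ai_edge_torch.signature("infer", model, sample_args).convert()
edge_output = edge_model(sample_args[0].numpy(), signature_name="infer")

assert np.allclose(edge_output, model(*sample_args).detach().numpy(), atol=1e-4, rtol=1e-5)
```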
ai_edge_torch/generative/utilities/transformers_verifier.py ADDED
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
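Note: combined with the verifier.py rewrite below, the wrappers are meant to be used roughly as follows (the checkpoint name, checkpoint path, and the `build_model` call are illustrative assumptions, not part of the diff):

```python
import transformers

from ai_edge_torch.generative.examples.smollm import smollm  # example target; any reauthored model works
from ai_edge_torch.generative.utilities import transformers_verifier
from ai_edge_torch.generative.utilities import verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"  # illustrative checkpoint

# Original HuggingFace model and tokenizer, wrapped for the verifier.
original = transformers_verifier.TransformersModelWrapper(
    transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
)
tokenizer = verifier.TokenizerWrapper(
    transformers.AutoTokenizer.from_pretrained(checkpoint)
)

# Model reauthored with the Generative API (builder name and path assumed).
reauthored = verifier.ReauthoredModelWrapper(
    smollm.build_model("/path/to/smollm/checkpoint")
)

verifier.verify_reauthored_model(
    original_model=original,
    reauthored_model=reauthored,
    tokenizer=tokenizer,
    generate_prompts=["What is the meaning of life?"],
    max_new_tokens=30,
)
```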
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -15,116 +15,130 @@
 
 """Common utility functions to verify the reauthored models."""
 
-import datetime
-from typing import List, Optional, Union
+import logging
+from typing import List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch
-import transformers
-
-
-def log_msg(*args):
-  print("[%s]" % datetime.datetime.now(), *args)
 
 
 class ModelWrapper(torch.nn.Module):
-  """A wrapper for the model to be verified, this could be a HuggingFace model
+  """A wrapper for the model to be verified.
 
-  or a regular PyTorch model.
+  It unifies the interface of forward() and generate() of models for the
+  verification to call.
   """
 
-  def __init__(
-      self,
-      model: torch.nn.Module,
-      model_format: str = "huggingface",
-      hf_generation_config: Optional[transformers.GenerationConfig] = None,
-  ):
+  def __init__(self, model: torch.nn.Module):
     """Initializes the wrapper.
 
     Args:
-      model (torch.nn.Module): The original model. This could be a model built
-        from HuggingFace transformers, or a regular PyTorch model.
-      model_format (str): The format of the model. It should be either
-        "huggingface" or "pytorch".
-      hf_generation_config (transformers.GenerationConfig): The HuggingFace
-        generation config. This config will only be used if the underlying model
-        is built from HuggingFace transformers.
+      model (torch.nn.Module): The model which might have different interfaces
+        of forward() and generate(). It could be a model built from HuggingFace
+        transformers, a regular PyTorch model, or a model re-authored with
+        ai_edge_torch Generative API.
     """
     super().__init__()
     self.model = model
-    self.model_format = model_format
-    self.hf_generation_config = hf_generation_config
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Gets output logits by forwarding the input tokens.
+
+    Args:
+      tokens (torch.Tensor): The input tokens to forward. Its dimension is
+        expected to be (batch_size=1, kv_cache_max_len).
+
+    Returns:
+      The output logits.
+    """
+    raise NotImplementedError("forward() is not implemented.")
 
   def generate(
-      self, inputs: torch.Tensor
-  ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
-    if self.model_format == "huggingface":
-      return self.model.generate(
-          inputs=inputs, generation_config=self.hf_generation_config
-      )
-    else:
-      raise NotImplementedError(
-          "generate() is not implemented for model format: %s"
-          % self.model_format
-      )
+      self, prompts: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    """Returns the response token IDs to the given prompts tensor.
 
-  def forward(
-      self,
-      inputs: torch.Tensor,
-  ):
-    return self.model.forward(inputs)
+    The maximum number of tokens to generate might be set by subclasses.
 
+    Args:
+      prompts (torch.Tensor): The input token IDs to generate with. Its shape is
+        expected to be (batch_size=1, input_ids_len).
+      max_new_tokens (int): The maximum number of response token IDs to
+        generate.
+
+    Returns:
+      The tensor of response token IDs with shape of (batch_size=1,
+      response_ids_len).
+    """
+    raise NotImplementedError("generate() is not implemented.")
 
-def forward(
-    model: torch.nn.Module,
-    tokens: torch.Tensor,
-    kv_cache: kv_utils.KVCache,
-) -> tuple[torch.Tensor, kv_utils.KVCache]:
-  """Forwards the model reauthored with ai_edge_torch Generative API.
 
-  Args:
-    model (torch.nn.Module): The model to forward. It should be a model built
-      with ai_edge_torch Generative API.
-    tokens (torch.Tensor): The input tokens to forward.
-    kv_cache (KVCache): The KV cache to forward.
+class ReauthoredModelWrapper(ModelWrapper):
+  """A wrapper for the model reauthored with ai_edge_torch Generative API."""
 
-  Returns:
-    The output logits and the updated KV cache.
-  """
-  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
-  output = model.forward(tokens, input_pos, kv_cache)
-  return output["logits"], output["kv_cache"]
+  def _init_kv_cache(self):
+    """Returns an initialized KV cache."""
+    return kv_utils.KVCache.from_model_config(self.model.config)
 
+  def _forward_with_kv_cache(
+      self,
+      tokens: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> tuple[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model and updates an external KV cache.
 
-def generate(
-    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
-) -> torch.Tensor:
-  """Generates the response to the prompts.
+    Args:
+      tokens (torch.Tensor): The input tokens to forward.
+      kv_cache (KVCache): The KV cache to forward.
 
-  It appends tokens output by the model to the prompts and feeds them back to
-  the model up to decode_len.
+    Returns:
+      The output logits and the updated KV cache.
+    """
+    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+    output = self.model.forward(tokens, input_pos, kv_cache)
+    return output["logits"], output["kv_cache"]
 
-  Args:
-    model (torch.nn.Module): The model to generate. It should be a model built
-      with ai_edge_torch Generative API.
-    prompts (torch.Tensor): The prompts to generate.
-    response_len (int): The number of tokens to generate.
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
+    return logits
 
-  Returns:
-    The generated tokens.
-  """
-  input_ids = prompts[0].int().tolist()
-  kv_cache = kv_utils.KVCache.from_model_config(model.config)
-  for _ in range(response_len - len(input_ids)):
-    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
-    generated_token = logits[0][-1].argmax().item()
-    input_ids.append(generated_token)
-  return torch.tensor([input_ids])
+  def generate(
+      self, prompts: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    input_ids = prompts[0].int().tolist()
+    kv_cache = self._init_kv_cache()
+    for _ in range(max_new_tokens):
+      tokens = torch.tensor([input_ids])
+      logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
+      generated_token = logits[0][-1].argmax().item()
+      input_ids.append(generated_token)
+    return torch.tensor([input_ids])
+
+
+class TokenizerWrapper(torch.nn.Module):
+  """A wrapper for the tokenizer used for verification."""
+
+  def __init__(self, tokenizer: torch.nn.Module):
+    """Initializes the wrapper.
+
+    Args:
+      tokenizer (torch.nn.Module): The tokenizer to wrap.
+    """
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def encode(self, prompts: str) -> torch.Tensor:
+    """Encodes the prompts to token IDs."""
+    return self.tokenizer.encode(prompts, return_tensors="pt")
+
+  def decode(self, token_ids: torch.Tensor) -> str:
+    """Decodes the token IDs to a string."""
+    return self.tokenizer.decode(token_ids)
 
 
 def verify_with_input_ids(
     original_model: ModelWrapper,
-    reauthored_model: torch.nn.Module,
+    reauthored_model: ReauthoredModelWrapper,
     input_ids: List[int],
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
@@ -136,8 +150,8 @@ def verify_with_input_ids(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-      Generative API.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
     input_ids (List[int]): The input token IDs to forward with.
    kv_cache_max_len (int): The maximum sequence length of the KV cache.
     rtol (float): The relative tolerance for the comparison.
@@ -149,16 +163,15 @@ def verify_with_input_ids(
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
 
-  log_msg("Forwarding the original model...")
+  logging.info("Forwarding the original model...")
   outputs_original = original_model.forward(tokens)
-  logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
-  log_msg("logits_original: ", logits_original)
+  logits_original = outputs_original[0, len(input_ids) - 1, :]
+  logging.info("logits_original: %s", logits_original)
 
-  log_msg("Forwarding the reauthored model...")
-  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
-  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-  logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
-  log_msg("logits_reauthored:", logits_reauthored)
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(tokens)
+  logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
+  logging.info("logits_reauthored: %s", logits_reauthored)
 
   return torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
@@ -167,9 +180,10 @@ def verify_with_input_ids(
 
 def verify_model_with_prompts(
     original_model: ModelWrapper,
-    reauthored_model: torch.nn.Module,
-    tokenizer: torch.nn.Module,
+    reauthored_model: ReauthoredModelWrapper,
+    tokenizer: TokenizerWrapper,
     prompts: str,
+    max_new_tokens: int,
 ) -> bool:
   """Verifies if the model reauthored generates the same answer of the oringal.
 
@@ -178,35 +192,36 @@
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-      Generative API.
-    tokenizer (torch.nn.Module): The tokenizer.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
+    tokenizer (TokenizerWrapper): The tokenizer.
     prompts (str): The input prompts to generate answers.
+    max_new_tokens (int): The maximum number of new tokens to generate.
 
   Returns:
     True if the model reauthored generates the same answer of the original.
   """
-  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+  prompt_tokens = tokenizer.encode(prompts)
 
-  log_msg("Generating answer with the original model...")
-  outputs_original = original_model.generate(prompt_tokens)
+  logging.info("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
   response_original = tokenizer.decode(outputs_original[0])
-  log_msg("outputs_from_original_model: [[", response_original, "]]")
+  logging.info("outputs_from_original_model: [[%s]]", response_original)
 
-  log_msg("Generating answer with the reauthored model...")
-  generate_len = len(outputs_original[0])
-  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  logging.info("Generating answer with the reauthored model...")
+  outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
-  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
   return response_original == response_reauthored
 
 
 def verify_reauthored_model(
     original_model: ModelWrapper,
-    reauthored_model: torch.nn.Module,
-    tokenizer: torch.nn.Module,
+    reauthored_model: ReauthoredModelWrapper,
+    tokenizer: TokenizerWrapper,
     generate_prompts: List[str],
+    max_new_tokens: int = 30,
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
@@ -223,29 +238,30 @@ def verify_reauthored_model(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-      Generative API.
-    tokenizer (torch.nn.Module): The tokenizer.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
+    tokenizer (TokenizerWrapper): The tokenizer.
     generate_prompts (List[str]): List of the input prompts to generate answers.
+    max_new_tokens (int): The maximum number of new tokens to generate.
     forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
      forward with.
     rtol (float): The relative tolerance for the comparison.
    atol (float): The absolute tolerance for the comparison.
  """
  for input_ids in forward_input_ids:
-    log_msg("Verifying the reauthored model with input IDs:", input_ids)
+    logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
    if verify_with_input_ids(
        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
    ):
-      log_msg("PASS")
+      logging.info("PASS")
    else:
-      log_msg("FAILED")
+      logging.error("FAILED")
 
  for prompts in generate_prompts:
-    log_msg("Verifying the reauthored model with prompts:", prompts)
+    logging.info("Verifying the reauthored model with prompts:%s", prompts)
    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts
+        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
    ):
-      log_msg("PASS")
+      logging.info("PASS")
    else:
-      log_msg("FAILED")
+      logging.error("FAILED")
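Note: because the ad-hoc `log_msg` helper is replaced with the standard `logging` module, the PASS/FAILED and generated-output messages are hidden unless logging is configured; a minimal, assumed setup before calling `verify_reauthored_model`:

```python
import logging

# INFO-level messages from verifier.py (PASS/FAILED, generated outputs) are
# suppressed by the default WARNING level; enable them explicitly.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s")
```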
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240924"
+__version__ = "0.3.0.dev20240928"
{ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240924
+Version: 0.3.0.dev20240928
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI