ai-edge-torch-nightly 0.3.0.dev20240924__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. The information is provided for informational purposes only.
Files changed (41)
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -10
  3. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  5. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  6. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  7. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  9. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  10. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  11. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  12. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  13. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  14. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  16. ai_edge_torch/generative/examples/phi/phi3.py +279 -0
  17. ai_edge_torch/generative/examples/phi/verify.py +13 -13
  18. ai_edge_torch/generative/examples/phi/verify_phi3.py +69 -0
  19. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  20. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  21. ai_edge_torch/generative/examples/stable_diffusion/clip.py +54 -1
  22. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +58 -0
  23. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +71 -1
  24. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  25. ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
  26. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -41
  27. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +5 -61
  28. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  29. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  30. ai_edge_torch/generative/layers/model_config.py +2 -0
  31. ai_edge_torch/generative/layers/normalization.py +2 -2
  32. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -2
  33. ai_edge_torch/generative/test/test_model_conversion_large.py +129 -0
  34. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  35. ai_edge_torch/generative/utilities/verifier.py +130 -114
  36. ai_edge_torch/version.py +1 -1
  37. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  38. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +41 -30
  39. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -83,6 +83,8 @@ class AttentionConfig:
   # Used to determine number of groups in grouped query attention (GQA)
   # https://arxiv.org/pdf/2305.13245.pdf
   num_query_groups: Optional[int]
+  # Base of rotary positional embedding.
+  rotary_base: int = 10_000
   # Percentage of Rotary Positional Embedding added Q and K projections.
   rotary_percentage: Optional[float] = None
   # Whether to transpose the query groups of qkv bundled tensor before
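The new `rotary_base` field exposes the RoPE frequency base as a per-model setting instead of an implicit constant. A hedged sketch of how a model definition might override it when building an attention config; only `num_query_groups`, `rotary_base`, and `rotary_percentage` are confirmed by this hunk, the other fields and the value 500_000 are illustrative assumptions:

    from ai_edge_torch.generative.layers import model_config as cfg

    # Hypothetical config; num_heads and head_dim are assumed fields not shown
    # in this hunk.
    attn_config = cfg.AttentionConfig(
        num_heads=16,
        head_dim=128,
        num_query_groups=4,
        rotary_base=500_000,  # override the 10_000 default, e.g. for long context
        rotary_percentage=1.0,
    )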
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -189,7 +189,7 @@ def group_norm_with_hlfb(
       name="odml.group_norm",
       attr={
           "num_groups": num_groups,
-          "eps": eps,
+          "epsilon": eps,
           "reduction_axes": 3,
           "channel_axis": 3,
       },
@@ -226,7 +226,7 @@ def layer_norm_with_hlfb(
   """
   builder = StableHLOCompositeBuilder(
       name="odml.group_norm",
-      attr={"num_groups": 1, "eps": eps, "channel_axis": 1},
+      attr={"num_groups": 1, "epsilon": eps, "channel_axis": 1},
   )
   x, w, b = builder.mark_inputs(x, w, b)
   if use_input_shape:
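Both renames keep the Python-side `eps` keyword but change the key recorded in the StableHLO composite's attributes to "epsilon", so downstream ODML tooling reads the new name. A minimal sketch of a composite-wrapped group norm using the renamed attribute; the import path and the mark_inputs/mark_outputs pattern mirror the usage visible in this diff and should be treated as assumptions, and the plain NCHW layout here is a simplification of the real channels-last helper:

    import torch
    import torch.nn.functional as F
    from ai_edge_torch.hlfb import StableHLOCompositeBuilder  # assumed import path


    def group_norm_composite(x: torch.Tensor, num_groups: int, eps: float = 1e-5):
      # Record the composite metadata; the attribute key is now "epsilon".
      builder = StableHLOCompositeBuilder(
          name="odml.group_norm",
          attr={"num_groups": num_groups, "epsilon": eps, "channel_axis": 1},
      )
      x = builder.mark_inputs(x)
      y = F.group_norm(x, num_groups, eps=eps)  # plain NCHW group norm
      return builder.mark_outputs(y)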
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers.attention import CrossAttention
 from ai_edge_torch.generative.layers.attention import SelfAttention
@@ -416,7 +416,7 @@ class DownEncoderBlock2D(nn.Module):
       time_emb: Optional[torch.Tensor] = None,
       context_tensor: Optional[torch.Tensor] = None,
       output_hidden_states: bool = False,
-  ) -> torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]:
+  ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
     """Forward function of the DownEncoderBlock2D.
 
     Args:
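The return-annotation change is a compatibility fix rather than a behavioral one: the PEP 604 `X | Y` syntax is only guaranteed to evaluate at runtime on Python 3.10+, while `typing.Union` also works on older interpreters (the package's exact minimum Python version is not stated in this diff). A small illustration:

    from typing import List, Tuple, Union

    import torch

    # Portable spelling, evaluates on pre-3.10 interpreters:
    HiddenStates = Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]

    # PEP 604 spelling; may raise TypeError when evaluated on Python < 3.10:
    # HiddenStates = torch.Tensor | Tuple[torch.Tensor, List[torch.Tensor]]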
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -19,9 +19,14 @@ import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
+from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.smollm import smollm
+from ai_edge_torch.generative.examples.stable_diffusion import clip as sd_clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder as sd_decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion as sd_diffusion
 from ai_edge_torch.generative.layers import kv_cache
 from ai_edge_torch.generative.test import utils as test_utils
 import numpy as np
@@ -98,6 +103,15 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = gemma2.Gemma2(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_llama(self):
+    config = llama.get_fake_model_config()
+    pytorch_model = llama.Llama(config).eval()
+    self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -109,6 +123,17 @@ class TestModelConversion(googletest.TestCase):
         config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
     )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_phi3(self):
+    config = phi3.get_fake_model_config()
+    pytorch_model = phi3.Phi3_5Mini(config).eval()
+    self._test_model(
+        config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -127,6 +152,110 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = openelm.OpenELM(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_clip(self):
+    config = sd_clip.get_fake_model_config()
+    prompt_tokens = torch.from_numpy(
+        np.array([[1, 2, 3, 4, 5, 6]], dtype=np.int32)
+    )
+
+    pytorch_model = sd_clip.CLIP(config).eval()
+    torch_output = pytorch_model(prompt_tokens)
+
+    edge_model = ai_edge_torch.signature(
+        "encode", pytorch_model, (prompt_tokens,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        prompt_tokens.numpy(),
+        signature_name="encode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_diffusion(self):
+    config = sd_diffusion.get_fake_model_config(2)
+    latents = torch.from_numpy(
+        np.random.normal(size=(2, 4, 8, 8)).astype(np.float32)
+    )
+    context = torch.from_numpy(
+        np.random.normal(size=(2, 4, 4)).astype(np.float32)
+    )
+    time_embedding = torch.from_numpy(
+        np.random.normal(size=(2, 2)).astype(np.float32)
+    )
+
+    pytorch_model = sd_diffusion.Diffusion(config).eval()
+    torch_output = pytorch_model(latents, context, time_embedding)
+
+    edge_model = ai_edge_torch.signature(
+        "diffusion", pytorch_model, (latents, context, time_embedding)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        context.numpy(),
+        time_embedding.numpy(),
+        signature_name="diffusion",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_stable_diffusion_decoder(self):
+    config = sd_decoder.get_fake_model_config()
+    latents = torch.from_numpy(
+        np.random.normal(size=(1, 4, 64, 64)).astype(np.float32)
+    )
+
+    pytorch_model = sd_decoder.Decoder(config).eval()
+    torch_output = pytorch_model(latents)
+
+    edge_model = ai_edge_torch.signature(
+        "decode", pytorch_model, (latents,)
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+    edge_output = edge_model(
+        latents.numpy(),
+        signature_name="decode",
+    )
+    self.assertTrue(
+        np.allclose(
+            edge_output,
+            torch_output.detach().numpy(),
+            atol=1e-4,
+            rtol=1e-5,
+        )
+    )
+
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/utilities/transformers_verifier.py ADDED
@@ -0,0 +1,42 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for the models predefined in HuggingFace transformers."""
+
+from typing import cast
+
+from ai_edge_torch.generative.utilities import verifier
+import torch
+import transformers
+
+
+class TransformersModelWrapper(verifier.ModelWrapper):
+  """A wrapper for the model predefined in HuggingFace transformers.
+
+  Verifier expects forward() to return logits while Transformers models return
+  an object with `logits` field.
+
+  Transformers models get `max_new_tokens` settings for generate() via
+  GenerationConfig.
+  """
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    return self.model.forward(tokens).logits
+
+  def generate(
+      self, inputs: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
+    return self.model.generate(inputs=inputs, generation_config=gen_config)
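A hedged usage sketch of the new wrapper: it adapts a HuggingFace causal LM so the verifier can call forward() for plain logits and generate() with an explicit token budget. The checkpoint name is illustrative only:

    import torch
    import transformers

    from ai_edge_torch.generative.utilities import transformers_verifier

    hf_model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")  # any causal LM
    wrapped = transformers_verifier.TransformersModelWrapper(hf_model)

    tokens = torch.tensor([[1, 2, 3, 4]])
    logits = wrapped.forward(tokens)                       # plain logits tensor
    response = wrapped.generate(tokens, max_new_tokens=8)  # greedy HF defaults apply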
ai_edge_torch/generative/utilities/verifier.py CHANGED
@@ -15,116 +15,130 @@
 
 """Common utility functions to verify the reauthored models."""
 
-import datetime
-from typing import List, Optional, Union
+import logging
+from typing import List
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import torch
-import transformers
-
-
-def log_msg(*args):
-  print("[%s]" % datetime.datetime.now(), *args)
 
 
 class ModelWrapper(torch.nn.Module):
-  """A wrapper for the model to be verified, this could be a HuggingFace model
+  """A wrapper for the model to be verified.
 
-  or a regular PyTorch model.
+  It unifies the interface of forward() and generate() of models for the
+  verification to call.
   """
 
-  def __init__(
-      self,
-      model: torch.nn.Module,
-      model_format: str = "huggingface",
-      hf_generation_config: Optional[transformers.GenerationConfig] = None,
-  ):
+  def __init__(self, model: torch.nn.Module):
     """Initializes the wrapper.
 
     Args:
-      model (torch.nn.Module): The original model. This could be a model built
-        from HuggingFace transformers, or a regular PyTorch model.
-      model_format (str): The format of the model. It should be either
-        "huggingface" or "pytorch".
-      hf_generation_config (transformers.GenerationConfig): The HuggingFace
-        generation config. This config will only be used if the underlying model
-        is built from HuggingFace transformers.
+      model (torch.nn.Module): The model which might have different interfaces
+        of forward() and generate(). It could be a model built from HuggingFace
+        transformers, a regular PyTorch model, or a model re-authored with
+        ai_edge_torch Generative API.
     """
     super().__init__()
     self.model = model
-    self.model_format = model_format
-    self.hf_generation_config = hf_generation_config
+
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Gets output logits by forwarding the input tokens.
+
+    Args:
+      tokens (torch.Tensor): The input tokens to forward. Its dimension is
+        expected to be (batch_size=1, kv_cache_max_len).
+
+    Returns:
+      The output logits.
+    """
+    raise NotImplementedError("forward() is not implemented.")
 
   def generate(
-      self, inputs: torch.Tensor
-  ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
-    if self.model_format == "huggingface":
-      return self.model.generate(
-          inputs=inputs, generation_config=self.hf_generation_config
-      )
-    else:
-      raise NotImplementedError(
-          "generate() is not implemented for model format: %s"
-          % self.model_format
-      )
+      self, prompts: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    """Returns the response token IDs to the given prompts tensor.
 
-  def forward(
-      self,
-      inputs: torch.Tensor,
-  ):
-    return self.model.forward(inputs)
+    The maximum number of tokens to generate might be set by subclasses.
 
+    Args:
+      prompts (torch.Tensor): The input token IDs to generate with. Its shape is
+        expected to be (batch_size=1, input_ids_len).
+      max_new_tokens (int): The maximum number of response token IDs to
+        generate.
+
+    Returns:
+      The tensor of response token IDs with shape of (batch_size=1,
+      response_ids_len).
+    """
+    raise NotImplementedError("generate() is not implemented.")
 
-def forward(
-    model: torch.nn.Module,
-    tokens: torch.Tensor,
-    kv_cache: kv_utils.KVCache,
-) -> tuple[torch.Tensor, kv_utils.KVCache]:
-  """Forwards the model reauthored with ai_edge_torch Generative API.
 
-  Args:
-    model (torch.nn.Module): The model to forward. It should be a model built
-      with ai_edge_torch Generative API.
-    tokens (torch.Tensor): The input tokens to forward.
-    kv_cache (KVCache): The KV cache to forward.
+class ReauthoredModelWrapper(ModelWrapper):
+  """A wrapper for the model reauthored with ai_edge_torch Generative API."""
 
-  Returns:
-    The output logits and the updated KV cache.
-  """
-  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
-  output = model.forward(tokens, input_pos, kv_cache)
-  return output["logits"], output["kv_cache"]
+  def _init_kv_cache(self):
+    """Returns an initialized KV cache."""
+    return kv_utils.KVCache.from_model_config(self.model.config)
 
+  def _forward_with_kv_cache(
+      self,
+      tokens: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+  ) -> tuple[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model and updates an external KV cache.
 
-def generate(
-    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
-) -> torch.Tensor:
-  """Generates the response to the prompts.
+    Args:
+      tokens (torch.Tensor): The input tokens to forward.
+      kv_cache (KVCache): The KV cache to forward.
 
-  It appends tokens output by the model to the prompts and feeds them back to
-  the model up to decode_len.
+    Returns:
+      The output logits and the updated KV cache.
+    """
+    input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+    output = self.model.forward(tokens, input_pos, kv_cache)
+    return output["logits"], output["kv_cache"]
 
-  Args:
-    model (torch.nn.Module): The model to generate. It should be a model built
-      with ai_edge_torch Generative API.
-    prompts (torch.Tensor): The prompts to generate.
-    response_len (int): The number of tokens to generate.
+  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+    logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
+    return logits
 
-  Returns:
-    The generated tokens.
-  """
-  input_ids = prompts[0].int().tolist()
-  kv_cache = kv_utils.KVCache.from_model_config(model.config)
-  for _ in range(response_len - len(input_ids)):
-    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
-    generated_token = logits[0][-1].argmax().item()
-    input_ids.append(generated_token)
-  return torch.tensor([input_ids])
+  def generate(
+      self, prompts: torch.Tensor, max_new_tokens: int
+  ) -> torch.IntTensor:
+    input_ids = prompts[0].int().tolist()
+    kv_cache = self._init_kv_cache()
+    for _ in range(max_new_tokens):
+      tokens = torch.tensor([input_ids])
+      logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
+      generated_token = logits[0][-1].argmax().item()
+      input_ids.append(generated_token)
+    return torch.tensor([input_ids])
+
+
+class TokenizerWrapper(torch.nn.Module):
+  """A wrapper for the tokenizer used for verification."""
+
+  def __init__(self, tokenizer: torch.nn.Module):
+    """Initializes the wrapper.
+
+    Args:
+      tokenizer (torch.nn.Module): The tokenizer to wrap.
+    """
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def encode(self, prompts: str) -> torch.Tensor:
+    """Encodes the prompts to token IDs."""
+    return self.tokenizer.encode(prompts, return_tensors="pt")
+
+  def decode(self, token_ids: torch.Tensor) -> str:
+    """Decodes the token IDs to a string."""
+    return self.tokenizer.decode(token_ids)
 
 
 def verify_with_input_ids(
     original_model: ModelWrapper,
-    reauthored_model: torch.nn.Module,
+    reauthored_model: ReauthoredModelWrapper,
     input_ids: List[int],
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
@@ -136,8 +150,8 @@ def verify_with_input_ids(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-      Generative API.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
     input_ids (List[int]): The input token IDs to forward with.
     kv_cache_max_len (int): The maximum sequence length of the KV cache.
     rtol (float): The relative tolerance for the comparison.
@@ -149,16 +163,15 @@ def verify_with_input_ids(
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
   tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
 
-  log_msg("Forwarding the original model...")
+  logging.info("Forwarding the original model...")
   outputs_original = original_model.forward(tokens)
-  logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
-  log_msg("logits_original: ", logits_original)
+  logits_original = outputs_original[0, len(input_ids) - 1, :]
+  logging.info("logits_original: %s", logits_original)
 
-  log_msg("Forwarding the reauthored model...")
-  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
-  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-  logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
-  log_msg("logits_reauthored:", logits_reauthored)
+  logging.info("Forwarding the reauthored model...")
+  outputs_reauthored = reauthored_model.forward(tokens)
+  logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
+  logging.info("logits_reauthored: %s", logits_reauthored)
 
   return torch.allclose(
       logits_original, logits_reauthored, rtol=rtol, atol=atol
@@ -167,9 +180,10 @@ def verify_with_input_ids(
 
 def verify_model_with_prompts(
     original_model: ModelWrapper,
-    reauthored_model: torch.nn.Module,
-    tokenizer: torch.nn.Module,
+    reauthored_model: ReauthoredModelWrapper,
+    tokenizer: TokenizerWrapper,
     prompts: str,
+    max_new_tokens: int,
 ) -> bool:
   """Verifies if the model reauthored generates the same answer of the oringal.
 
@@ -178,35 +192,36 @@ def verify_model_with_prompts(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-      Generative API.
-    tokenizer (torch.nn.Module): The tokenizer.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
+    tokenizer (TokenizerWrapper): The tokenizer.
     prompts (str): The input prompts to generate answers.
+    max_new_tokens (int): The maximum number of new tokens to generate.
 
   Returns:
    True if the model reauthored generates the same answer of the original.
  """
-  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+  prompt_tokens = tokenizer.encode(prompts)
 
-  log_msg("Generating answer with the original model...")
-  outputs_original = original_model.generate(prompt_tokens)
+  logging.info("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
   response_original = tokenizer.decode(outputs_original[0])
-  log_msg("outputs_from_original_model: [[", response_original, "]]")
+  logging.info("outputs_from_original_model: [[%s]]", response_original)
 
-  log_msg("Generating answer with the reauthored model...")
-  generate_len = len(outputs_original[0])
-  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  logging.info("Generating answer with the reauthored model...")
+  outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
-  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
 
   return response_original == response_reauthored
 
 
 def verify_reauthored_model(
     original_model: ModelWrapper,
-    reauthored_model: torch.nn.Module,
-    tokenizer: torch.nn.Module,
+    reauthored_model: ReauthoredModelWrapper,
+    tokenizer: TokenizerWrapper,
     generate_prompts: List[str],
+    max_new_tokens: int = 30,
     forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
@@ -223,29 +238,30 @@ def verify_reauthored_model(
 
   Args:
     original_model (ModelWrapper): The original model.
-    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
-      Generative API.
-    tokenizer (torch.nn.Module): The tokenizer.
+    reauthored_model (ReauthoredModelWrapper): The model reauthored with
+      ai_edge_torch Generative API.
+    tokenizer (TokenizerWrapper): The tokenizer.
     generate_prompts (List[str]): List of the input prompts to generate answers.
+    max_new_tokens (int): The maximum number of new tokens to generate.
     forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
       forward with.
     rtol (float): The relative tolerance for the comparison.
    atol (float): The absolute tolerance for the comparison.
  """
   for input_ids in forward_input_ids:
-    log_msg("Verifying the reauthored model with input IDs:", input_ids)
+    logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
    if verify_with_input_ids(
        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
    ):
-      log_msg("PASS")
+      logging.info("PASS")
    else:
-      log_msg("FAILED")
+      logging.error("FAILED")
 
   for prompts in generate_prompts:
-    log_msg("Verifying the reauthored model with prompts:", prompts)
+    logging.info("Verifying the reauthored model with prompts:%s", prompts)
    if verify_model_with_prompts(
-        original_model, reauthored_model, tokenizer, prompts
+        original_model, reauthored_model, tokenizer, prompts, max_new_tokens
    ):
-      log_msg("PASS")
+      logging.info("PASS")
    else:
-      log_msg("FAILED")
+      logging.error("FAILED")
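Taken together, the reworked verifier now expects everything behind wrappers: the original model behind a ModelWrapper subclass (e.g. TransformersModelWrapper), the reauthored model behind ReauthoredModelWrapper, and the tokenizer behind TokenizerWrapper, with the generation budget passed as max_new_tokens. A hedged end-to-end sketch; the checkpoint name, the build_reauthored_model() helper, and the logging setup are illustrative assumptions, not part of this diff:

    import logging

    import transformers

    from ai_edge_torch.generative.utilities import transformers_verifier
    from ai_edge_torch.generative.utilities import verifier

    logging.basicConfig(level=logging.INFO)  # verifier now logs via logging, not print

    checkpoint = "some-org/some-causal-lm"  # illustrative only
    original = transformers_verifier.TransformersModelWrapper(
        transformers.AutoModelForCausalLM.from_pretrained(checkpoint)
    )
    tokenizer = verifier.TokenizerWrapper(
        transformers.AutoTokenizer.from_pretrained(checkpoint)
    )
    reauthored = verifier.ReauthoredModelWrapper(
        build_reauthored_model(checkpoint)  # hypothetical Generative API builder
    )

    verifier.verify_reauthored_model(
        original_model=original,
        reauthored_model=reauthored,
        tokenizer=tokenizer,
        generate_prompts=["What is the meaning of life?"],
        max_new_tokens=30,
        atol=1e-04,
    )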
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240924"
+__version__ = "0.3.0.dev20240928"
{ai_edge_torch_nightly-0.3.0.dev20240924.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240924
+Version: 0.3.0.dev20240928
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI