ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240920__py3-none-any.whl

Files changed (32)
  1. ai_edge_torch/_convert/test/test_convert.py +7 -3
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +6 -4
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
  5. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  6. ai_edge_torch/generative/examples/openelm/verify.py +63 -0
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
  8. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  9. ai_edge_torch/generative/examples/phi/verify.py +63 -0
  10. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
  11. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  12. ai_edge_torch/generative/examples/smollm/verify.py +60 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +62 -0
  16. ai_edge_torch/generative/layers/builder.py +3 -1
  17. ai_edge_torch/generative/layers/model_config.py +3 -0
  18. ai_edge_torch/generative/layers/normalization.py +31 -20
  19. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  20. ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
  21. ai_edge_torch/generative/layers/unet/model_config.py +1 -0
  22. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +2 -2
  24. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
  25. ai_edge_torch/generative/utilities/verifier.py +249 -0
  26. ai_edge_torch/model.py +7 -4
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/RECORD +32 -27
  30. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/tiny_llama/verify.py ADDED
@@ -0,0 +1,62 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Verifies the reauthored TinyLlama-1.1B model."""
+
+ import pathlib
+
+ from absl import app
+ from absl import flags
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+ from ai_edge_torch.generative.utilities import verifier
+ import transformers
+
+ _PROMPTS = flags.DEFINE_multi_string(
+     "prompts",
+     "Show me the program to add 2 and 3.",
+     "The input prompts to generate answers.",
+ )
+
+
+ def main(_):
+   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+   verifier.log_msg("Loading the original model from", checkpoint)
+   wrapper_model = verifier.ModelWrapper(
+       model=transformers.AutoModelForCausalLM.from_pretrained(
+           checkpoint, trust_remote_code=True
+       ),
+   )
+   # Locate the cached dir.
+   cached_config_file = transformers.utils.cached_file(
+       checkpoint, transformers.utils.CONFIG_NAME
+   )
+   reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+   verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+   reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+   verifier.log_msg("Loading the tokenizer from", checkpoint)
+   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+   verifier.verify_reauthored_model(
+       original_model=wrapper_model,
+       reauthored_model=reauthored_model,
+       tokenizer=tokenizer,
+       prompts=_PROMPTS.value,
+       atol=1e-04,
+   )
+
+
+ if __name__ == "__main__":
+   app.run(main)
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
          zero_centered_gamma=config.zero_centered,
      )
    elif config.type == cfg.NormalizationType.LAYER_NORM:
-     return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+     return normalization.LayerNorm(
+         dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+     )
    elif config.type == cfg.NormalizationType.GROUP_NORM:
      return normalization.GroupNorm(
          config.group_num, dim, config.epsilon, config.enable_hlfb
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -69,6 +69,9 @@ class NormalizationConfig:
    zero_centered: bool = False
    # Number of groups used in group normalization.
    group_num: Optional[float] = None
+   # Whether to use the input shape to determine the dimension of normalization
+   # when type is LAYER_NORM.
+   use_input_shape: bool = True


  @dataclass
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
      group_num (int): Number of groups to separate the channels into.
      dim (int): Dimension of the input tensor.
      eps (float): A small float value to ensure numerical stability (default:
-       1e-6).
+       1e-5).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.
    """
@@ -112,7 +112,13 @@ class GroupNorm(torch.nn.Module):

  class LayerNorm(torch.nn.Module):

-   def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+   def __init__(
+       self,
+       dim: int,
+       eps: float = 1e-5,
+       enable_hlfb: bool = False,
+       use_input_shape: bool = True,
+   ):
      """Initialize the LayerNorm layer.

      Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
        1e-6).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.
+     use_input_shape (bool): Whether to use the input shape to determine the
+       dimension of normalization (default: True).
      """
      super().__init__()
      self.enable_hlfb = enable_hlfb
+     self.use_input_shape = use_input_shape
      self.eps = eps
      self.weight = torch.nn.Parameter(torch.ones(dim))
      self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@ class LayerNorm(torch.nn.Module):
      """
      if self.enable_hlfb:
        return layer_norm_with_hlfb(
-           x,
-           self.weight,
-           self.bias,
-           self.eps,
+           x, self.weight, self.bias, self.eps, self.use_input_shape
        )
+
+     if self.use_input_shape:
+       normalized_shape = x.shape
+       weight = self.weight.broadcast_to(x.shape)
+       bias = self.bias.broadcast_to(x.shape)
      else:
-       return F.layer_norm(
-           x,
-           x.shape,
-           self.weight.broadcast_to(x.shape),
-           self.bias.broadcast_to(x.shape),
-           self.eps,
-       )
+       normalized_shape = self.weight.shape
+       weight = self.weight
+       bias = self.bias
+     return F.layer_norm(x, normalized_shape, weight, bias, self.eps)


  def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
      w: torch.Tensor,
      b: torch.Tensor,
      eps: float,
+     use_input_shape: bool,
  ):
    """Layer Normalization with high-level function boundary enabled.

@@ -201,18 +210,20 @@ def layer_norm_with_hlfb(
      w (torch.Tensor): The weight tensor for the normalization.
      b (torch.Tensor): The bias tensor for the normalization.
      eps (float): A small float value to ensure numerical stability.
+     use_input_shape (bool): Whether to use the input shape to determine the
+       dimension of normalization.

    Returns:
      The output tensor of Layer Normalization.
    """
    builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
    x, w, b = builder.mark_inputs(x, w, b)
-   y = F.layer_norm(
-       x,
-       x.shape,
-       weight=w.broadcast_to(x.shape),
-       bias=b.broadcast_to(x.shape),
-       eps=eps,
-   )
+   if use_input_shape:
+     normalized_shape = x.shape
+     w = w.broadcast_to(x.shape)
+     b = b.broadcast_to(x.shape)
+   else:
+     normalized_shape = w.shape
+   y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
    y = builder.mark_outputs(y)
    return y
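
To see what the new use_input_shape flag changes, here is a minimal standalone sketch in plain PyTorch (shapes and values are illustrative, not taken from the library): with use_input_shape=True the statistics are computed over the full input shape with weight and bias broadcast to match, while with use_input_shape=False they are computed only over the trailing weight dimensions, which is conventional LayerNorm over the feature axis.

import torch
import torch.nn.functional as F

# Illustrative shapes; x is (batch, seq, dim) and weight/bias are per-feature.
x = torch.randn(2, 4, 8)
weight = torch.ones(8)
bias = torch.zeros(8)

# use_input_shape=True: normalize over the entire input shape, broadcasting
# weight and bias to match (the previous, and still default, behavior).
y_full = F.layer_norm(
    x, x.shape, weight.broadcast_to(x.shape), bias.broadcast_to(x.shape), 1e-5
)

# use_input_shape=False: normalize only over the trailing weight dimensions,
# i.e. conventional LayerNorm over the feature axis.
y_last = F.layer_norm(x, weight.shape, weight, bias, 1e-5)

print(y_full.shape, y_last.shape)  # both torch.Size([2, 4, 8])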
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
    # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
    k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
    v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-   y = F.scaled_dot_product_attention(
-       q,
-       k,
-       v,
-       attn_mask=mask,
-       dropout_p=0.0,
-       is_causal=mask is None,
-       scale=scale,
-   )
+   if softcap is None:
+     y = F.scaled_dot_product_attention(
+         q,
+         k,
+         v,
+         attn_mask=mask,
+         dropout_p=0.0,
+         is_causal=mask is None,
+         scale=scale,
+     )
+   else:
+     q.mul_(scale)
+     scores = q @ k.transpose(-1, -2)
+     scores = scores / softcap
+     scores = torch.tanh(scores)
+     scores = scores * softcap
+     scores = scores + mask
+     out = F.softmax(scores.float(), dim=-1).type_as(q)
+     y = torch.matmul(out, v)

    result = y.transpose(1, 2)
    result = builder.mark_outputs(result)
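
The new softcap branch implements logit soft-capping: pre-softmax attention scores are squashed through tanh so their magnitude never exceeds softcap, a technique used by Gemma 2 style models. A minimal self-contained sketch of that score transform, assuming an additive mask and the standard (batch, heads, seq, head_dim) layout; the helper name is ours, not the library's:

import math
import torch
import torch.nn.functional as F

def softcapped_attention(q, k, v, mask, softcap, scale=None):
  # Sketch of the soft-capped path above; mask is additive (0 or -inf entries).
  if scale is None:
    scale = 1.0 / math.sqrt(q.shape[-1])
  scores = (q * scale) @ k.transpose(-1, -2)
  scores = softcap * torch.tanh(scores / softcap)  # bounded to (-softcap, softcap)
  scores = scores + mask
  probs = F.softmax(scores.float(), dim=-1).type_as(q)
  return probs @ v

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
mask = torch.zeros(1, 1, 4, 4)
print(softcapped_attention(q, k, v, mask, softcap=50.0).shape)  # (1, 2, 4, 8)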
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
      )
      self.conv_1 = nn.Conv2d(
          config.in_channels,
-         config.out_channels,
+         config.hidden_channels,
          kernel_size=3,
          stride=1,
          padding=1,
      )
      if config.time_embedding_channels is not None:
        self.time_emb_proj = nn.Linear(
-           config.time_embedding_channels, config.out_channels
+           config.time_embedding_channels, config.hidden_channels
        )
      else:
        self.time_emb_proj = None
      self.norm_2 = layers_builder.build_norm(
-         config.out_channels, config.normalization_config
+         config.hidden_channels, config.normalization_config
      )
      self.conv_2 = nn.Conv2d(
-         config.out_channels,
+         config.hidden_channels,
          config.out_channels,
          kernel_size=3,
          stride=1,
@@ -391,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -492,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=input_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -602,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=resnet_in_channels + res_skip_channels,
+                 hidden_channels=config.out_channels,
                  out_channels=config.out_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -706,6 +709,7 @@ class MidBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
+                 hidden_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
@@ -724,6 +728,7 @@ class MidBlock2D(nn.Module):
          ResidualBlock2D(
              unet_cfg.ResidualBlock2DConfig(
                  in_channels=config.in_channels,
+                 hidden_channels=config.in_channels,
                  out_channels=config.in_channels,
                  time_embedding_channels=config.time_embedding_channels,
                  normalization_config=config.normalization_config,
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -48,6 +48,7 @@ class DownSamplingConfig:

  @dataclasses.dataclass
  class ResidualBlock2DConfig:
    in_channels: int
+   hidden_channels: int
    out_channels: int
    normalization_config: layers_cfg.NormalizationConfig
    activation_config: layers_cfg.ActivationConfig
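
The new hidden_channels field decouples the width of the block's intermediate activations from its output width; before this change both convolutions in ResidualBlock2D were sized by out_channels. A standalone sketch of the resulting channel flow (plain PyTorch; the sizes are made up for illustration):

import torch
from torch import nn

# Channel flow of ResidualBlock2D after this change:
#   conv_1: in_channels -> hidden_channels
#   conv_2: hidden_channels -> out_channels
in_channels, hidden_channels, out_channels = 32, 64, 64
conv_1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1)
conv_2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1)

x = torch.randn(1, in_channels, 16, 16)
h = conv_2(conv_1(x))
print(h.shape)  # torch.Size([1, 64, 16, 16])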
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
  import torch

  from absl.testing import absltest as googletest
- from tensorflow.lite.python import interpreter
+ from ai_edge_litert import interpreter


  class TestModelConversion(googletest.TestCase):
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -28,7 +28,7 @@ import numpy as np
  import torch

  from absl.testing import absltest as googletest
- from tensorflow.lite.python import interpreter
+ from ai_edge_litert import interpreter


  class TestModelConversion(googletest.TestCase):
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
    def test_gemma2(self):
      config = gemma2.get_fake_model_config()
      pytorch_model = gemma2.Gemma2(config).eval()
-     self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
+     self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

    @googletest.skipIf(
        ai_edge_config.Config.use_torch_xla,
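
For context on the tightened test_gemma2 tolerances: allclose-style checks pass elementwise when |actual - expected| <= atol + rtol * |expected|, so moving from (1e-1, 1e-3) to (1e-4, 1e-5) is roughly three orders of magnitude stricter. A quick sketch with made-up values:

import numpy as np

expected = np.array([1.0, 100.0])
actual = expected + np.array([5e-5, 2e-3])

# Old tolerances: loose enough that a 2e-3 error on a value of 100 passes.
print(np.allclose(actual, expected, atol=1e-1, rtol=1e-3))  # True
# New tolerances: the same error now fails (2e-3 > 1e-4 + 1e-5 * 100).
print(np.allclose(actual, expected, atol=1e-4, rtol=1e-5))  # False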
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
    ):
      residual_block_config = unet_config.ResidualBlock2DConfig(
          in_channels=config.in_channels,
+         hidden_channels=config.in_channels,
          out_channels=config.in_channels,
          time_embedding_channels=config.time_embedding_channels,
          normalization_config=config.normalization_config,
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=input_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
          f"{converted_state_param_prefix}.resnets.{i}",
          unet_config.ResidualBlock2DConfig(
              in_channels=resnet_in_channels + res_skip_channels,
+             hidden_channels=config.out_channels,
              out_channels=config.out_channels,
              time_embedding_channels=config.time_embedding_channels,
              normalization_config=config.normalization_config,
ai_edge_torch/generative/utilities/verifier.py ADDED
@@ -0,0 +1,249 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Common utility functions to verify the reauthored models."""
+
+ import datetime
+ from typing import List, Optional, Union
+
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
+ import numpy as np
+ import torch
+ import transformers
+
+
+ def log_msg(*args):
+   print("[%s]" % datetime.datetime.now(), *args)
+
+
+ class ModelWrapper(torch.nn.Module):
+   """A wrapper for the model to be verified.
+
+   It can be a HuggingFace model or a regular PyTorch model.
+   """
+
+   def __init__(
+       self,
+       model: torch.nn.Module,
+       model_format: str = "huggingface",
+       hf_generation_config: Optional[transformers.GenerationConfig] = None,
+   ):
+     """Initializes the wrapper.
+
+     Args:
+       model (torch.nn.Module): The original model. This could be a model built
+         from HuggingFace transformers, or a regular PyTorch model.
+       model_format (str): The format of the model. It should be either
+         "huggingface" or "pytorch".
+       hf_generation_config (transformers.GenerationConfig): The HuggingFace
+         generation config. This config is used only if the underlying model
+         is built from HuggingFace transformers.
+     """
+     super().__init__()
+     self.model = model
+     self.model_format = model_format
+     self.hf_generation_config = hf_generation_config
+
+   def generate(
+       self, inputs: torch.Tensor
+   ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
+     if self.model_format == "huggingface":
+       return self.model.generate(
+           inputs=inputs, generation_config=self.hf_generation_config
+       )
+     else:
+       raise NotImplementedError(
+           "generate() is not implemented for model format: %s"
+           % self.model_format
+       )
+
+   def forward(
+       self,
+       inputs: torch.Tensor,
+   ):
+     return self.model.forward(inputs)
+
+
+ def forward(
+     model: torch.nn.Module,
+     tokens: torch.Tensor,
+     kv_cache: kv_utils.KVCache,
+ ) -> tuple[torch.Tensor, kv_utils.KVCache]:
+   """Forwards the model reauthored with ai_edge_torch Generative API.
+
+   Args:
+     model (torch.nn.Module): The model to forward. It should be a model built
+       with ai_edge_torch Generative API.
+     tokens (torch.Tensor): The input tokens to forward.
+     kv_cache (KVCache): The KV cache to forward.
+
+   Returns:
+     The output logits and the updated KV cache.
+   """
+   input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+   output = model.forward(tokens, input_pos, kv_cache)
+   return output["logits"], output["kv_cache"]
+
+
+ def generate(
+     model: torch.nn.Module, prompts: torch.Tensor, response_len: int
+ ) -> torch.Tensor:
+   """Generates the response to the prompts.
+
+   It appends tokens output by the model to the prompts and feeds them back
+   into the model, up to response_len tokens.
+
+   Args:
+     model (torch.nn.Module): The model to generate with. It should be a model
+       built with ai_edge_torch Generative API.
+     prompts (torch.Tensor): The prompts to generate from.
+     response_len (int): The number of tokens to generate.
+
+   Returns:
+     The generated tokens.
+   """
+   input_ids = prompts[0].int().tolist()
+   kv_cache = kv_utils.KVCache.from_model_config(model.config)
+   for _ in range(response_len - len(input_ids)):
+     logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
+     generated_token = logits[0][-1].argmax().item()
+     input_ids.append(generated_token)
+   return torch.tensor([input_ids])
+
+
+ def verify_with_input_ids(
+     original_model: ModelWrapper,
+     reauthored_model: torch.nn.Module,
+     input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+     kv_cache_max_len: int = 1024,
+     rtol: float = 1e-05,
+     atol: float = 1e-05,
+ ) -> bool:
+   """Verifies the reauthored model generates the same output as the original.
+
+   It compares only one output from the original and the reauthored model.
+
+   Args:
+     original_model (ModelWrapper): The original model.
+     reauthored_model (torch.nn.Module): The model reauthored with
+       ai_edge_torch Generative API.
+     input_ids (torch.Tensor): The input token IDs to forward.
+     kv_cache_max_len (int): The maximum sequence length of the KV cache.
+     rtol (float): The relative tolerance for the comparison.
+     atol (float): The absolute tolerance for the comparison.
+
+   Returns:
+     True if the reauthored model generates the same output as the original.
+   """
+   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+   input_ids_len = input_ids.shape[1]
+   tokens[0, :input_ids_len] = input_ids
+
+   log_msg("Forwarding the original model...")
+   outputs_original = original_model.forward(tokens)
+   logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+   log_msg("logits_original:", logits_original)
+
+   log_msg("Forwarding the reauthored model...")
+   kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
+   outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
+   logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+   log_msg("logits_reauthored:", logits_reauthored)
+
+   return torch.allclose(
+       logits_original, logits_reauthored, rtol=rtol, atol=atol
+   )
+
+
+ def verify_model_with_prompts(
+     original_model: ModelWrapper,
+     reauthored_model: torch.nn.Module,
+     tokenizer: torch.nn.Module,
+     prompts: str,
+ ) -> bool:
+   """Verifies the reauthored model generates the same answer as the original.
+
+   It compares an answer, i.e. multiple continuous outputs generated by the
+   original and the reauthored model.
+
+   Args:
+     original_model (ModelWrapper): The original model.
+     reauthored_model (torch.nn.Module): The model reauthored with
+       ai_edge_torch Generative API.
+     tokenizer (torch.nn.Module): The tokenizer.
+     prompts (str): The input prompt to generate an answer for.
+
+   Returns:
+     True if the reauthored model generates the same answer as the original.
+   """
+   prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+
+   log_msg("Generating answer with the original model...")
+   outputs_original = original_model.generate(prompt_tokens)
+   response_original = tokenizer.decode(outputs_original[0])
+   log_msg("outputs_from_original_model: [[", response_original, "]]")
+
+   log_msg("Generating answer with the reauthored model...")
+   generate_len = len(outputs_original[0])
+   outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+   response_reauthored = tokenizer.decode(outputs_reauthored[0])
+   log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+
+   return response_original == response_reauthored
+
+
+ def verify_reauthored_model(
+     original_model: ModelWrapper,
+     reauthored_model: torch.nn.Module,
+     tokenizer: torch.nn.Module,
+     prompts: List[str],
+     rtol: float = 1e-05,
+     atol: float = 1e-05,
+ ):
+   """Verifies the reauthored model against the original model.
+
+   It verifies the reauthored model with two methods:
+   1. It compares the output of the original and the reauthored model with an
+      arbitrary input.
+   2. It compares the answer generated by the original and the reauthored model
+      with a prompt.
+
+   It prints out "PASS" or "FAILED" to the console.
+
+   Args:
+     original_model (ModelWrapper): The original model.
+     reauthored_model (torch.nn.Module): The model reauthored with
+       ai_edge_torch Generative API.
+     tokenizer (torch.nn.Module): The tokenizer.
+     prompts (List[str]): List of the input prompts to generate answers.
+     rtol (float): The relative tolerance for the comparison.
+     atol (float): The absolute tolerance for the comparison.
+   """
+   log_msg("Verifying the reauthored model with an arbitrary input...")
+   if verify_with_input_ids(
+       original_model, reauthored_model, rtol=rtol, atol=atol
+   ):
+     log_msg("PASS")
+   else:
+     log_msg("FAILED")
+
+   for p in prompts:
+     log_msg("Verifying the reauthored model with prompts:", p)
+     if verify_model_with_prompts(
+         original_model, reauthored_model, tokenizer, p
+     ):
+       log_msg("PASS")
+     else:
+       log_msg("FAILED")
ai_edge_torch/model.py CHANGED
@@ -27,6 +27,8 @@ from typing import Callable
  import numpy.typing as npt
  import tensorflow as tf

+ from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
+
  DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY


@@ -65,7 +67,7 @@ class TfLiteModel(Model):
        tflite_model: A TFlite serialized object.
      """
      self._tflite_model = tflite_model
-     self._interpreter_builder = lambda: tf.lite.Interpreter(
+     self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
          model_content=self._tflite_model,
          experimental_default_delegate_latest_features=True,
      )
@@ -75,12 +77,13 @@ class TfLiteModel(Model):
      return self._tflite_model

    def set_interpreter_builder(
-       self, builder: Callable[[], tf.lite.Interpreter]
+       self, builder: Callable[[], tfl_interpreter.Interpreter]
    ) -> None:
      """Sets a custom interpreter builder.

      Args:
-       builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+       builder: A function that returns a `tfl_interpreter.Interpreter` or its
+         subclass.
      """
      self._interpreter_builder = builder

@@ -166,7 +169,7 @@ class TfLiteModel(Model):

      # Check if this is indeed a tflite model:
      try:
-       interpreter = tf.lite.Interpreter(model_content=model_content)
+       interpreter = tfl_interpreter.Interpreter(model_content=model_content)
        interpreter.get_signature_list()
      except:
        return None
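
With this change the interpreter comes from the standalone ai_edge_litert package rather than tf.lite. A minimal usage sketch, assuming ai_edge_litert is installed and with "model.tflite" as a placeholder path; the class mirrors the tf.lite.Interpreter API used elsewhere in this diff, so call sites only need the import swap:

from ai_edge_litert import interpreter as tfl_interpreter

# Load a serialized TFLite model; "model.tflite" is a placeholder path.
with open("model.tflite", "rb") as f:
  model_content = f.read()

interp = tfl_interpreter.Interpreter(model_content=model_content)
interp.allocate_tensors()
# Same signature-based API as tf.lite.Interpreter:
print(interp.get_signature_list())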
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20240918"
+ __version__ = "0.3.0.dev20240920"
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240918
+ Version: 0.3.0.dev20240920
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI