ai-edge-torch-nightly 0.3.0.dev20240918__py3-none-any.whl → 0.3.0.dev20240920__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. ai_edge_torch/_convert/test/test_convert.py +7 -3
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
  3. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +6 -4
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
  5. ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
  6. ai_edge_torch/generative/examples/openelm/verify.py +63 -0
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
  8. ai_edge_torch/generative/examples/phi/phi2.py +4 -31
  9. ai_edge_torch/generative/examples/phi/verify.py +63 -0
  10. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
  11. ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
  12. ai_edge_torch/generative/examples/smollm/verify.py +60 -0
  13. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
  14. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
  15. ai_edge_torch/generative/examples/tiny_llama/verify.py +62 -0
  16. ai_edge_torch/generative/layers/builder.py +3 -1
  17. ai_edge_torch/generative/layers/model_config.py +3 -0
  18. ai_edge_torch/generative/layers/normalization.py +31 -20
  19. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
  20. ai_edge_torch/generative/layers/unet/blocks_2d.py +9 -4
  21. ai_edge_torch/generative/layers/unet/model_config.py +1 -0
  22. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +2 -2
  24. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +4 -0
  25. ai_edge_torch/generative/utilities/verifier.py +249 -0
  26. ai_edge_torch/model.py +7 -4
  27. ai_edge_torch/version.py +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA +1 -1
  29. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/RECORD +32 -27
  30. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/LICENSE +0 -0
  31. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/WHEEL +0 -0
  32. {ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/tiny_llama/verify.py ADDED
@@ -0,0 +1,62 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored TinyLlama-1.1B model."""
+
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.utilities import verifier
+import transformers
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "Show me the program to add 2 and 3.",
+    "The input prompts to generate answers.",
+)
+
+
+def main(_):
+  checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+  verifier.log_msg("Loading the original model from", checkpoint)
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
+  )
+  # Locate the cached dir.
+  cached_config_file = transformers.utils.cached_file(
+      checkpoint, transformers.utils.CONFIG_NAME
+  )
+  reauthored_checkpoint = pathlib.Path(cached_config_file).parent
+  verifier.log_msg("Building the reauthored model from", reauthored_checkpoint)
+  reauthored_model = tiny_llama.build_model(reauthored_checkpoint)
+
+  verifier.log_msg("Loading the tokenizer from", checkpoint)
+  tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
+
+  verifier.verify_reauthored_model(
+      original_model=wrapper_model,
+      reauthored_model=reauthored_model,
+      tokenizer=tokenizer,
+      prompts=_PROMPTS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -75,7 +75,9 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
+    return normalization.LayerNorm(
+        dim, config.epsilon, config.enable_hlfb, config.use_input_shape
+    )
   elif config.type == cfg.NormalizationType.GROUP_NORM:
     return normalization.GroupNorm(
         config.group_num, dim, config.epsilon, config.enable_hlfb
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -69,6 +69,9 @@ class NormalizationConfig:
   zero_centered: bool = False
   # Number of groups used in group normalization.
   group_num: Optional[float] = None
+  # Whether to use the input shape to determine the dimension of normalization
+  # when type is LAYER_NORM.
+  use_input_shape: bool = True


 @dataclass
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -78,7 +78,7 @@ class GroupNorm(torch.nn.Module):
     group_num (int): Number of groups to separate the channels into.
     dim (int): Dimension of the input tensor.
     eps (float): A small float value to ensure numerical stability (default:
-      1e-6).
+      1e-5).
     enable_hlfb (bool): Whether to convert this normalization into a single
       op.
   """
@@ -112,7 +112,13 @@ class GroupNorm(torch.nn.Module):

 class LayerNorm(torch.nn.Module):

-  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+  def __init__(
+      self,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+      use_input_shape: bool = True,
+  ):
     """Initialize the LayerNorm layer.

     Args:
@@ -121,9 +127,12 @@ class LayerNorm(torch.nn.Module):
       1e-6).
     enable_hlfb (bool): Whether to convert this normalization into a single
       op.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization (default: True).
   """
   super().__init__()
   self.enable_hlfb = enable_hlfb
+  self.use_input_shape = use_input_shape
   self.eps = eps
   self.weight = torch.nn.Parameter(torch.ones(dim))
   self.bias = torch.nn.Parameter(torch.ones(dim))
@@ -139,19 +148,18 @@ class LayerNorm(torch.nn.Module):
     """
     if self.enable_hlfb:
       return layer_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.eps,
+          x, self.weight, self.bias, self.eps, self.use_input_shape
       )
+
+    if self.use_input_shape:
+      normalized_shape = x.shape
+      weight = self.weight.broadcast_to(x.shape)
+      bias = self.bias.broadcast_to(x.shape)
     else:
-      return F.layer_norm(
-          x,
-          x.shape,
-          self.weight.broadcast_to(x.shape),
-          self.bias.broadcast_to(x.shape),
-          self.eps,
-      )
+      normalized_shape = self.weight.shape
+      weight = self.weight
+      bias = self.bias
+    return F.layer_norm(x, normalized_shape, weight, bias, self.eps)


 def group_norm_with_hlfb(
@@ -193,6 +201,7 @@ def layer_norm_with_hlfb(
     w: torch.Tensor,
     b: torch.Tensor,
     eps: float,
+    use_input_shape: bool,
 ):
   """Layer Normalization with high-level function boundary enabled.

@@ -201,18 +210,20 @@ def layer_norm_with_hlfb(
     w (torch.Tensor): The weight tensor for the normalization.
     b (torch.Tensor): The bias tensor for the normalization.
     eps (float): A small float value to ensure numerical stability.
+    use_input_shape (bool): Whether to use the input shape to determine the
+      dimension of normalization.

   Returns:
     The output tensor of Layer Normalization.
   """
   builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
   x, w, b = builder.mark_inputs(x, w, b)
-  y = F.layer_norm(
-      x,
-      x.shape,
-      weight=w.broadcast_to(x.shape),
-      bias=b.broadcast_to(x.shape),
-      eps=eps,
-  )
+  if use_input_shape:
+    normalized_shape = x.shape
+    w = w.broadcast_to(x.shape)
+    b = b.broadcast_to(x.shape)
+  else:
+    normalized_shape = w.shape
+  y = F.layer_norm(x, normalized_shape, w, b, eps=eps)
   y = builder.mark_outputs(y)
   return y
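In short, use_input_shape=True keeps the previous behavior: F.layer_norm normalizes over the full input shape with the per-channel weight and bias broadcast to match, while use_input_shape=False falls back to the conventional last-dimension layer norm. A minimal sketch (illustrative shapes, not taken from the diff) of the two code paths:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)  # (batch, seq, dim)
weight, bias, eps = torch.ones(8), torch.zeros(8), 1e-5

# use_input_shape=True: normalize over every dimension of x, with the
# per-channel weight/bias broadcast to the full input shape.
y_full = F.layer_norm(
    x, x.shape, weight.broadcast_to(x.shape), bias.broadcast_to(x.shape), eps
)

# use_input_shape=False: normalize over the trailing weight dimensions only,
# matching torch.nn.LayerNorm(8).
y_last = F.layer_norm(x, weight.shape, weight, bias, eps)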
ai_edge_torch/generative/layers/scaled_dot_product_attention.py CHANGED
@@ -119,15 +119,25 @@ def scaled_dot_product_attention_with_hlfb(
   # Handle the GQA case, where q.shape[1] % k.shape[1] == 0.
   k = k.repeat_interleave(q.shape[1] // k.shape[1], dim=1)
   v = v.repeat_interleave(q.shape[1] // v.shape[1], dim=1)
-  y = F.scaled_dot_product_attention(
-      q,
-      k,
-      v,
-      attn_mask=mask,
-      dropout_p=0.0,
-      is_causal=mask is None,
-      scale=scale,
-  )
+  if softcap is None:
+    y = F.scaled_dot_product_attention(
+        q,
+        k,
+        v,
+        attn_mask=mask,
+        dropout_p=0.0,
+        is_causal=mask is None,
+        scale=scale,
+    )
+  else:
+    q.mul_(scale)
+    scores = q @ k.transpose(-1, -2)
+    scores = scores / softcap
+    scores = torch.tanh(scores)
+    scores = scores * softcap
+    scores = scores + mask
+    out = F.softmax(scores.float(), dim=-1).type_as(q)
+    y = torch.matmul(out, v)

   result = y.transpose(1, 2)
   result = builder.mark_outputs(result)
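The new softcap branch implements tanh logit softcapping (the scheme Gemma 2 uses): attention scores are squashed into (-softcap, softcap) before the mask and softmax, i.e. scores = tanh((q·kᵀ·scale)/softcap)·softcap. A self-contained sketch of that branch, assuming q, k, v of shape (B, H, S, D) and an additive mask:

import torch
import torch.nn.functional as F

def softcapped_attention(q, k, v, mask, scale, softcap):
  # Scaled attention logits.
  scores = (q * scale) @ k.transpose(-1, -2)
  # Squash logits into (-softcap, softcap) before masking and softmax.
  scores = torch.tanh(scores / softcap) * softcap
  scores = scores + mask  # additive attention mask
  probs = F.softmax(scores.float(), dim=-1).type_as(q)
  return probs @ v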
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
     )
     self.conv_1 = nn.Conv2d(
         config.in_channels,
-        config.out_channels,
+        config.hidden_channels,
         kernel_size=3,
         stride=1,
         padding=1,
     )
     if config.time_embedding_channels is not None:
       self.time_emb_proj = nn.Linear(
-          config.time_embedding_channels, config.out_channels
+          config.time_embedding_channels, config.hidden_channels
       )
     else:
       self.time_emb_proj = None
     self.norm_2 = layers_builder.build_norm(
-        config.out_channels, config.normalization_config
+        config.hidden_channels, config.normalization_config
     )
     self.conv_2 = nn.Conv2d(
-        config.out_channels,
+        config.hidden_channels,
         config.out_channels,
         kernel_size=3,
         stride=1,
@@ -391,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
+                  hidden_channels=config.out_channels,
                   out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
@@ -492,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=input_channels,
+                  hidden_channels=config.out_channels,
                   out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
@@ -602,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=resnet_in_channels + res_skip_channels,
+                  hidden_channels=config.out_channels,
                   out_channels=config.out_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
@@ -706,6 +709,7 @@ class MidBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=config.in_channels,
+                  hidden_channels=config.in_channels,
                   out_channels=config.in_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
@@ -724,6 +728,7 @@ class MidBlock2D(nn.Module):
           ResidualBlock2D(
               unet_cfg.ResidualBlock2DConfig(
                   in_channels=config.in_channels,
+                  hidden_channels=config.in_channels,
                   out_channels=config.in_channels,
                   time_embedding_channels=config.time_embedding_channels,
                   normalization_config=config.normalization_config,
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -48,6 +48,7 @@ class DownSamplingConfig:
 @dataclasses.dataclass
 class ResidualBlock2DConfig:
   in_channels: int
+  hidden_channels: int
   out_channels: int
   normalization_config: layers_cfg.NormalizationConfig
   activation_config: layers_cfg.ActivationConfig
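The new hidden_channels field decouples a residual block's internal width from its output width: conv_1 now maps in_channels to hidden_channels and conv_2 maps hidden_channels to out_channels. Every call site in this diff passes the previous value (out_channels, or in_channels for MidBlock2D), so behavior is unchanged. A minimal PyTorch sketch (illustrative sizes, not the library's class) of that channel flow:

import torch
import torch.nn as nn

in_ch, hidden_ch, out_ch = 128, 256, 256
# Same kernel/stride/padding as the conv layers in the diff above.
conv_1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, stride=1, padding=1)
conv_2 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, stride=1, padding=1)

x = torch.randn(1, in_ch, 32, 32)
y = conv_2(conv_1(x))
assert y.shape == (1, out_ch, 32, 32)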
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -25,7 +25,7 @@ import numpy as np
 import torch

 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter


 class TestModelConversion(googletest.TestCase):
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -28,7 +28,7 @@ import numpy as np
 import torch

 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter


 class TestModelConversion(googletest.TestCase):
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
   def test_gemma2(self):
     config = gemma2.get_fake_model_config()
     pytorch_model = gemma2.Gemma2(config).eval()
-    self._test_model(config, pytorch_model, "prefill", atol=1e-1, rtol=1e-3)
+    self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
   ):
     residual_block_config = unet_config.ResidualBlock2DConfig(
         in_channels=config.in_channels,
+        hidden_channels=config.in_channels,
         out_channels=config.in_channels,
         time_embedding_channels=config.time_embedding_channels,
         normalization_config=config.normalization_config,
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
         f"{converted_state_param_prefix}.resnets.{i}",
         unet_config.ResidualBlock2DConfig(
             in_channels=input_channels,
+            hidden_channels=config.out_channels,
             out_channels=config.out_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
         f"{converted_state_param_prefix}.resnets.{i}",
         unet_config.ResidualBlock2DConfig(
             in_channels=input_channels,
+            hidden_channels=config.out_channels,
             out_channels=config.out_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
         f"{converted_state_param_prefix}.resnets.{i}",
         unet_config.ResidualBlock2DConfig(
             in_channels=resnet_in_channels + res_skip_channels,
+            hidden_channels=config.out_channels,
             out_channels=config.out_channels,
             time_embedding_channels=config.time_embedding_channels,
             normalization_config=config.normalization_config,
ai_edge_torch/generative/utilities/verifier.py ADDED
@@ -0,0 +1,249 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utility functions to verify the reauthored models."""
+
+import datetime
+from typing import List, Optional, Union
+
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import numpy as np
+import torch
+import transformers
+
+
+def log_msg(*args):
+  print("[%s]" % datetime.datetime.now(), *args)
+
+
+class ModelWrapper(torch.nn.Module):
+  """A wrapper for the model to be verified; this could be a HuggingFace model
+
+  or a regular PyTorch model.
+  """
+
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      model_format: str = "huggingface",
+      hf_generation_config: Optional[transformers.GenerationConfig] = None,
+  ):
+    """Initializes the wrapper.
+
+    Args:
+      model (torch.nn.Module): The original model. This could be a model built
+        from HuggingFace transformers, or a regular PyTorch model.
+      model_format (str): The format of the model. It should be either
+        "huggingface" or "pytorch".
+      hf_generation_config (transformers.GenerationConfig): The HuggingFace
+        generation config. This config will only be used if the underlying model
+        is built from HuggingFace transformers.
+    """
+    super().__init__()
+    self.model = model
+    self.model_format = model_format
+    self.hf_generation_config = hf_generation_config
+
+  def generate(
+      self, inputs: torch.Tensor
+  ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
+    if self.model_format == "huggingface":
+      return self.model.generate(
+          inputs=inputs, generation_config=self.hf_generation_config
+      )
+    else:
+      raise NotImplementedError(
+          "generate() is not implemented for model format: %s"
+          % self.model_format
+      )
+
+  def forward(
+      self,
+      inputs: torch.Tensor,
+  ):
+    return self.model.forward(inputs)
+
+
+def forward(
+    model: torch.nn.Module,
+    tokens: torch.Tensor,
+    kv_cache: kv_utils.KVCache,
+) -> tuple[torch.Tensor, kv_utils.KVCache]:
+  """Forwards the model reauthored with ai_edge_torch Generative API.
+
+  Args:
+    model (torch.nn.Module): The model to forward. It should be a model built
+      with ai_edge_torch Generative API.
+    tokens (torch.Tensor): The input tokens to forward.
+    kv_cache (KVCache): The KV cache to forward.
+
+  Returns:
+    The output logits and the updated KV cache.
+  """
+  input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
+  output = model.forward(tokens, input_pos, kv_cache)
+  return output["logits"], output["kv_cache"]
+
+
+def generate(
+    model: torch.nn.Module, prompts: torch.Tensor, response_len: int
+) -> torch.Tensor:
+  """Generates the response to the prompts.
+
+  It appends tokens output by the model to the prompts and feeds them back to
+  the model up to response_len.
+
+  Args:
+    model (torch.nn.Module): The model to generate. It should be a model built
+      with ai_edge_torch Generative API.
+    prompts (torch.Tensor): The prompts to generate.
+    response_len (int): The number of tokens to generate.
+
+  Returns:
+    The generated tokens.
+  """
+  input_ids = prompts[0].int().tolist()
+  kv_cache = kv_utils.KVCache.from_model_config(model.config)
+  for _ in range(response_len - len(input_ids)):
+    logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
+    generated_token = logits[0][-1].argmax().item()
+    input_ids.append(generated_token)
+  return torch.tensor([input_ids])
+
+
+def verify_with_input_ids(
+    original_model: ModelWrapper,
+    reauthored_model: torch.nn.Module,
+    input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
+    kv_cache_max_len: int = 1024,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+) -> bool:
+  """Verifies if the reauthored model generates the same output as the original.
+
+  It compares only one output from the original and the reauthored model.
+
+  Args:
+    original_model (ModelWrapper): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    input_ids (torch.Tensor): The input token IDs to forward.
+    kv_cache_max_len (int): The maximum sequence length of the KV cache.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+
+  Returns:
+    True if the reauthored model generates the same output as the original.
+  """
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
+  input_ids_len = input_ids.shape[1]
+  tokens[0, :input_ids_len] = input_ids
+
+  log_msg("Forwarding the original model...")
+  outputs_original = original_model.forward(tokens)
+  logits_original = outputs_original.logits[0, input_ids_len - 1, :]
+  log_msg("logits_original: ", logits_original)
+
+  log_msg("Forwarding the reauthored model...")
+  kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
+  outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
+  logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
+  log_msg("logits_reauthored:", logits_reauthored)
+
+  return torch.allclose(
+      logits_original, logits_reauthored, rtol=rtol, atol=atol
+  )
+
+
+def verify_model_with_prompts(
+    original_model: ModelWrapper,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: str,
+) -> bool:
+  """Verifies if the reauthored model generates the same answer as the original.
+
+  It compares an answer, i.e. multiple continuous outputs generated by the
+  original and the reauthored model.
+
+  Args:
+    original_model (ModelWrapper): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (str): The input prompts to generate answers.
+
+  Returns:
+    True if the reauthored model generates the same answer as the original.
+  """
+  prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
+
+  log_msg("Generating answer with the original model...")
+  outputs_original = original_model.generate(prompt_tokens)
+  response_original = tokenizer.decode(outputs_original[0])
+  log_msg("outputs_from_original_model: [[", response_original, "]]")
+
+  log_msg("Generating answer with the reauthored model...")
+  generate_len = len(outputs_original[0])
+  outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
+  response_reauthored = tokenizer.decode(outputs_reauthored[0])
+  log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
+
+  return response_original == response_reauthored
+
+
+def verify_reauthored_model(
+    original_model: ModelWrapper,
+    reauthored_model: torch.nn.Module,
+    tokenizer: torch.nn.Module,
+    prompts: List[str],
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored model against the original model.
+
+  It verifies the reauthored model with two methods:
+  1. It compares the output of the original and the reauthored model with an
+     arbitrary input.
+  2. It compares the answer generated by the original and the reauthored model
+     with a prompt.
+
+  It prints out "PASS" or "FAILED" to the console.
+
+  Args:
+    original_model (ModelWrapper): The original model.
+    reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
+      Generative API.
+    tokenizer (torch.nn.Module): The tokenizer.
+    prompts (List[str]): List of the input prompts to generate answers.
+    rtol (float): The relative tolerance for the comparison.
+    atol (float): The absolute tolerance for the comparison.
+  """
+  log_msg("Verifying the reauthored model with an arbitrary input...")
+  if verify_with_input_ids(
+      original_model, reauthored_model, rtol=rtol, atol=atol
+  ):
+    log_msg("PASS")
+  else:
+    log_msg("FAILED")
+
+  for p in prompts:
+    log_msg("Verifying the reauthored model with prompts:", p)
+    if verify_model_with_prompts(
+        original_model, reauthored_model, tokenizer, p
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")
ai_edge_torch/model.py CHANGED
@@ -27,6 +27,8 @@ from typing import Callable
 import numpy.typing as npt
 import tensorflow as tf

+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
+
 DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY


@@ -65,7 +67,7 @@ class TfLiteModel(Model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
-    self._interpreter_builder = lambda: tf.lite.Interpreter(
+    self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
         model_content=self._tflite_model,
         experimental_default_delegate_latest_features=True,
     )
@@ -75,12 +77,13 @@ class TfLiteModel(Model):
     return self._tflite_model

   def set_interpreter_builder(
-      self, builder: Callable[[], tf.lite.Interpreter]
+      self, builder: Callable[[], tfl_interpreter.Interpreter]
   ) -> None:
     """Sets a custom interpreter builder.

     Args:
-      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+      builder: A function that returns a `tfl_interpreter.Interpreter` or its
+        subclass.
     """
     self._interpreter_builder = builder

@@ -166,7 +169,7 @@

     # Check if this is indeed a tflite model:
     try:
-      interpreter = tf.lite.Interpreter(model_content=model_content)
+      interpreter = tfl_interpreter.Interpreter(model_content=model_content)
       interpreter.get_signature_list()
     except:
       return None
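A minimal sketch of what the swap above amounts to at the call site: the interpreter now comes from the standalone ai_edge_litert package rather than tf.lite, with the same constructor arguments used in the diff (the model path here is a placeholder):

from ai_edge_litert import interpreter as tfl_interpreter

with open("model.tflite", "rb") as f:  # placeholder path
  model_content = f.read()

interp = tfl_interpreter.Interpreter(
    model_content=model_content,
    experimental_default_delegate_latest_features=True,
)
interp.get_signature_list()  # sanity-check that this is a valid TFLite model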
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================

-__version__ = "0.3.0.dev20240918"
+__version__ = "0.3.0.dev20240920"
{ai_edge_torch_nightly-0.3.0.dev20240918.dist-info → ai_edge_torch_nightly-0.3.0.dev20240920.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240918
+Version: 0.3.0.dev20240920
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI