ai-edge-torch-nightly 0.3.0.dev20240829__py3-none-any.whl → 0.3.0.dev20240831__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


ai_edge_torch/generative/examples/experimental/gemma/gemma.py CHANGED
@@ -159,15 +159,44 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config_2b_for_test(**kwargs) -> cfg.ModelConfig:
-  config = get_model_config_2b(**kwargs)
-  config.num_layers = 2
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
   return config
 
 
 def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
   config = (
-      get_fake_model_config_2b_for_test(**kwargs)
+      get_fake_model_config(**kwargs)
       if test_model
       else get_model_config_2b(**kwargs)
   )
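
Across the example models in this release, the `get_fake_model_config_*_for_test` helpers are replaced by self-contained `get_fake_model_config` functions, so tests no longer mutate a real 2B config in place. A minimal sketch of how the renamed helper is exercised, mirroring the quantize example and tests further down (import paths taken from this diff):

    import ai_edge_torch
    import numpy as np
    import torch
    from ai_edge_torch.generative.examples.gemma import gemma

    # Build the tiny 2-layer fake Gemma and run a conversion smoke test.
    config = gemma.get_fake_model_config()
    model = gemma.Gemma(config)
    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
    tokens[0, :4] = idx
    input_pos = torch.arange(0, 10)
    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))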
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -147,9 +147,39 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
-  config = get_model_config_2b()
-  config.num_layers = 2
+# TODO(b/363021962): Clean up this part to streamline fake model config generation.
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
   return config
 
 
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -209,9 +209,47 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config_2b_for_test() -> cfg.ModelConfig:
-  config = get_model_config_2b()
-  config.num_layers = 2
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=4,
+      head_dim=64,
+      num_query_groups=4,
+      rotary_percentage=1.0,
+      qkv_transpose_before_split=True,
+      logit_softcap=50.0,
+      sliding_window_size=64,
+      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
+      * 13,
+  )
+
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      embedding_dim=128,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
   return config
 
 
ai_edge_torch/generative/examples/phi2/phi2.py CHANGED
@@ -139,9 +139,36 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config_for_test() -> cfg.ModelConfig:
-  config = get_model_config()
-  config.num_layers = 2
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=80,
+      num_query_groups=4,
+      rotary_percentage=0.4,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=128,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=True,
+      lm_head_use_bias=True,
+      enable_hlfb=True,
+  )
   return config
 
 
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -137,11 +137,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def get_fake_model_config_for_test() -> cfg.ModelConfig:
+def get_fake_model_config() -> cfg.ModelConfig:
   config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.ff_config.intermediate_size = 256
+  config.ff_config.intermediate_size = 64
   return config
 
 
ai_edge_torch/generative/quantize/example.py CHANGED
@@ -22,7 +22,7 @@ import torch
 
 def main():
   # Build a PyTorch model as usual
-  config = gemma.get_fake_model_config_2b_for_test()
+  config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -16,6 +16,7 @@
 import copy
 
 import ai_edge_torch
+from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma, gemma2
 from ai_edge_torch.generative.examples.phi2 import phi2
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
@@ -25,11 +26,27 @@ import numpy as np
 import torch
 
 from absl.testing import absltest as googletest
+from tensorflow.lite.python import interpreter
 
 
 class TestModelConversion(googletest.TestCase):
   """Unit tests that check for model conversion and correctness."""
 
+  def setUp(self):
+    super().setUp()
+    # Builder function for an Interpreter that supports custom ops.
+    self._interpreter_builder = (
+        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
+            custom_op_registerers=["GenAIOpsRegisterer"],
+            model_content=tflite_model,
+            experimental_default_delegate_latest_features=True,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
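
The `setUp` builder above is deliberately curried: the outer lambda binds the serialized model bytes, and the inner zero-argument lambda is what `TfLiteModel.set_interpreter_builder` stores and calls lazily (see the `ai_edge_torch/model.py` change below). Roughly, the two-step call used throughout these tests looks like:

    # Illustration only: `edge_model` is the result of ai_edge_torch.convert().
    builder = self._interpreter_builder(edge_model.tflite_model())  # binds the bytes
    edge_model.set_interpreter_builder(builder)  # invoked on each inference call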
@@ -38,22 +55,27 @@ class TestModelConversion(googletest.TestCase):
     )
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (idx, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (idx, input_pos),
+            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_toy_model_with_multi_batches(self):
+    self.skipTest("b/362842043")
     config = toy_model_with_kv_cache.get_model_config()
     config.batch_size = 2
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
@@ -62,21 +84,25 @@ class TestModelConversion(googletest.TestCase):
     )
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (idx, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (idx, input_pos),
+            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
@@ -86,49 +112,27 @@ class TestModelConversion(googletest.TestCase):
     )
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (idx, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
-
-  def test_tiny_llama(self):
-    self.skipTest("b/338288901")
-    config = tiny_llama.get_fake_model_config_for_test()
-    pytorch_model = tiny_llama.TinyLLamma(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (idx, input_pos),
+            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_tiny_llama_multisig(self):
-    config = tiny_llama.get_fake_model_config_for_test()
+    config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLLamma(config).eval()
 
     # prefill
@@ -149,22 +153,25 @@ class TestModelConversion(googletest.TestCase):
         .signature("decode", pytorch_model, (decode_token, decode_input_pos))
         .convert()
     )
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      copied_model = copy.deepcopy(pytorch_model)
+    copied_model = copy.deepcopy(pytorch_model)
 
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (prefill_tokens, prefill_input_pos),
-              signature_name="prefill",
-              num_valid_inputs=1,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (prefill_tokens, prefill_input_pos),
+            signature_name="prefill",
+            num_valid_inputs=1,
+        )
+    )
 
+    # TODO(b/362840003): figure why this decode output has big numerical diff.
+    skip_output_check = True
+    if not skip_output_check:
       self.assertTrue(
           model_coverage.compare_tflite_torch(
              edge_model,
@@ -175,87 +182,6 @@ class TestModelConversion(googletest.TestCase):
          )
       )
 
-  def test_gemma(self):
-    self.skipTest("b/338288901")
-    config = gemma.get_fake_model_config_2b_for_test()
-    model = gemma.Gemma(config)
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      # TODO(talumbau, haoliang): debug numerical diff.
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-2,
-              rtol=1e-5,
-          )
-      )
-
-  def test_gemma2(self):
-    self.skipTest("b/338288901")
-    config = gemma2.get_fake_model_config_2b_for_test()
-    model = gemma2.Gemma2(config)
-    model.eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      # TODO(talumbau, haoliang): debug numerical diff.
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-2,
-              rtol=1e-5,
-          )
-      )
-
-  def test_phi2(self):
-    self.skipTest("b/338288901")
-    config = phi2.get_fake_model_config_for_test()
-    pytorch_model = phi2.Phi2(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
-
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/test/test_model_conversion_large.py ADDED
@@ -0,0 +1,139 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Testing model conversion for a few gen-ai models.
+import copy
+
+import ai_edge_torch
+from ai_edge_torch import config as ai_edge_config
+from ai_edge_torch.generative.examples.gemma import gemma, gemma2
+from ai_edge_torch.generative.examples.phi2 import phi2
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.testing import model_coverage
+import numpy as np
+import torch
+
+from absl.testing import absltest as googletest
+from tensorflow.lite.python import interpreter
+
+
+class TestModelConversion(googletest.TestCase):
+  """Unit tests that check for model conversion and correctness."""
+
+  def setUp(self):
+    super().setUp()
+    # Builder function for an Interpreter that supports custom ops.
+    self._interpreter_builder = (
+        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
+            custom_op_registerers=["GenAIOpsRegisterer"],
+            model_content=tflite_model,
+            experimental_default_delegate_latest_features=True,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    model = gemma.Gemma(config)
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            model,
+            (tokens, input_pos),
+            num_valid_inputs=1,
+            atol=1e-2,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    model = gemma2.Gemma2(config)
+    model.eval()
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    # TODO(b/362840003): debug numerical diff.
+    skip_output_check = True
+    if not skip_output_check:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-2,
+              rtol=1e-5,
+          )
+      )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_phi2(self):
+    config = phi2.get_fake_model_config()
+    pytorch_model = phi2.Phi2(config).eval()
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (tokens, input_pos),
+            num_valid_inputs=1,
+            atol=1e-3,
+            rtol=1e-3,
+        )
+    )
+
+
+if __name__ == "__main__":
+  googletest.main()
ai_edge_torch/model.py CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
 
 import abc
 import re
+from typing import Callable
 
 import numpy.typing as npt
 import tensorflow as tf
@@ -64,6 +65,24 @@ class TfLiteModel(Model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
+    self._interpreter_builder = lambda: tf.lite.Interpreter(
+        model_content=self._tflite_model,
+        experimental_default_delegate_latest_features=True,
+    )
+
+  def tflite_model(self) -> bytes:
+    """Returns the wrapped tflite model."""
+    return self._tflite_model
+
+  def set_interpreter_builder(
+      self, builder: Callable[[], tf.lite.Interpreter]
+  ) -> None:
+    """Sets a custom interpreter builder.
+
+    Args:
+      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+    """
+    self._interpreter_builder = builder
 
   def __call__(
       self,
@@ -80,10 +99,7 @@ class TfLiteModel(Model):
       signature_name: The name of the signature to be used for inference. The
         default signature is used if not provided.
     """
-    interpreter = tf.lite.Interpreter(
-        model_content=self._tflite_model,
-        experimental_default_delegate_latest_features=True,
-    )
+    interpreter = self._interpreter_builder()
     interpreter.allocate_tensors()
 
     signature_list = interpreter.get_signature_list()
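
Together, the new `tflite_model()` accessor and `set_interpreter_builder()` hook let callers swap the default `tf.lite.Interpreter` for a subclass, such as one with the GenAI custom ops registered as the tests above do. A minimal sketch, assuming `edge_model` is the `TfLiteModel` returned by `ai_edge_torch.convert` and `tokens`/`input_pos` are sample inputs as in the tests:

    from tensorflow.lite.python import interpreter as interpreter_lib

    model_bytes = edge_model.tflite_model()
    edge_model.set_interpreter_builder(
        lambda: interpreter_lib.InterpreterWithCustomOps(
            custom_op_registerers=["GenAIOpsRegisterer"],
            model_content=model_bytes,
            experimental_default_delegate_latest_features=True,
        )
    )
    # Later calls construct the interpreter through the custom builder.
    result = edge_model(tokens, input_pos)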
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240829"
+__version__ = "0.3.0.dev20240831"
{ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240831.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240829
+Version: 0.3.0.dev20240831
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240831.dist-info}/RECORD RENAMED
@@ -1,8 +1,8 @@
 ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
-ai_edge_torch/model.py,sha256=7tox6sdFIlCYPLDYpjFcD8cPTSivURCL_VV6-Dt5Sfc,4910
-ai_edge_torch/version.py,sha256=OF9oSdUOGcmdEp2HSZmEIeCPlRhL3cpviHc_dExhcX8,706
+ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
+ai_edge_torch/version.py,sha256=j78jEAdvuHPxuAOpjMJFUnPUQA0hPynGaNAPjNtw2SI,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,7 +42,7 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=8313wSsddvuxZ5ZYVdaITBV2FF1k22dcCujnq0UZvKs,6699
+ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=EdElPCDLYxnNvkPMJkE3WKvESze1ehgShEk2NnbrXLg,7527
 ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
 ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231
@@ -52,11 +52,11 @@ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=z
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=q9Zil66EvRKrSpLVQHxKHu_8NL0HAgY2FbtThoTZVUY,8226
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=pzD9dYUYg8E6fFACh-8B8G9NHFXOVEWBjf5aDeipU2s,7202
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=ypd6uBb4FgDpuWm_w8JNYBAf4eFxWbYccs8vCgBhi-I,9374
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=91mWxEtKgDtUhCAewWNwH_UOOCzy6tPdf6LNRlxZhrc,6700
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -82,7 +82,7 @@ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.p
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mXXFYJfo8yegSOFOndCR0oYxFPchYb9vTJ4ThXGIFLU,5940
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=JmwU1sniO37vnCFc8dklbd-0ofTZK0PaBv_Ksn1Vq6M,5930
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -100,7 +100,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3Xqsgr
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=mqi3zFUp4w198DGnRkmZCWUZdUXTkvg1_tdTdOk9IkA,1535
+ai_edge_torch/generative/quantize/example.py,sha256=Bmc-WowIJIfDgt84CNw2LhyLRi7SFcw8BQEu4byTKJU,1523
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -111,7 +111,8 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
 ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=-RBTQSERP4szm8s8s_WRmGF3mWZA5E2w2QNtl2MqORw,8475
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=wQLVjMnKHBCVCU_I-xAUZvlOFoDiwYwKQDvCZ2mjtOM,6193
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
 ai_edge_torch/generative/test/test_quantize.py,sha256=JEsk9SAkHK0SFm44K_quISc5yBBS6yvtBP1MDyFHdFw,5344
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
@@ -161,8 +162,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240829.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240829.dist-info/METADATA,sha256=LrexCNdY177vrp17WaGa53bxHH9vuZXT64O5by4HE6Y,1878
-ai_edge_torch_nightly-0.3.0.dev20240829.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240829.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240829.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240831.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240831.dist-info/METADATA,sha256=yQLF91cZImFohJbXO693TGcFtA5GEWuaGYAJZKm9oPE,1878
+ai_edge_torch_nightly-0.3.0.dev20240831.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240831.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240831.dist-info/RECORD,,