ai-edge-torch-nightly 0.2.0.dev20240605__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of ai-edge-torch-nightly has been flagged as possibly problematic.

@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
   """Unit tests that check for model conversion and correctness."""
 
   def test_toy_model_with_kv_cache(self):
-    self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
     idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
@@ -42,19 +41,21 @@ class TestModelConversion(unittest.TestCase):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_toy_model_with_kv_cache_with_hlfb(self):
-    self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
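Every edit in this file follows the same pattern: instead of skipping the whole test up front with self.skipTest(...), the conversion still runs and only the output comparison is gated behind a skip_output_check flag. A minimal standalone sketch of the difference; the test class, build_and_convert, and outputs_match here are stand-ins of ours, not names from the package:

import unittest


def build_and_convert():
  # Stand-in for ai_edge_torch.convert(...); returns a dummy "model".
  return object()


def outputs_match(model):
  # Stand-in for model_coverage.compare_tflite_torch(...).
  return True


class PatternDemo(unittest.TestCase):
  """Hypothetical tests illustrating the change in this release."""

  def test_old_style(self):
    self.skipTest("b/338288901")  # raises SkipTest: nothing below runs
    build_and_convert()           # unreachable, conversion never exercised

  def test_new_style(self):
    model = build_and_convert()   # conversion still runs and can fail loudly
    # TODO(b/338288901): re-enable the output check.
    skip_output_check = True
    if skip_output_check is False:
      self.assertTrue(outputs_match(model))


if __name__ == "__main__":
  unittest.main()

The upside of the new style is that a regression in conversion itself still breaks the test; only the numerical output check is deferred.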
@@ -64,16 +65,19 @@ class TestModelConversion(unittest.TestCase):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (idx, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (idx, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_tiny_llama(self):
     self.skipTest("b/338288901")
@@ -87,19 +91,21 @@ class TestModelConversion(unittest.TestCase):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
   def test_tiny_llama_multisig(self):
-    self.skipTest("b/338288901")
    config = tiny_llama.get_fake_model_config_for_test()
    pytorch_model = tiny_llama.TinyLLamma(config)
 
@@ -122,32 +128,30 @@ class TestModelConversion(unittest.TestCase):
         .convert()
     )
 
-    # For the pytorch model, the KV cache is a persistent state internal to the model, and it
-    # will be shared for prefill and decode. However, for tflite, currently we can't share
-    # kv-cache between the two signatures. prefill will change the content in kv-cache,
-    # but it won't be readable by the decode tflite model. This means the output of running `decode` after
-    # running `prefill` in pytorch will be different from the output of running `decode` after `prefill` via ai_edge_torch.
-    copied_model = copy.deepcopy(pytorch_model)
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (prefill_tokens, prefill_input_pos),
-            signature_name="prefill",
-            num_valid_inputs=1,
-        )
-    )
-
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            copied_model,
-            (decode_token, decode_input_pos),
-            signature_name="decode",
-            num_valid_inputs=1,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      copied_model = copy.deepcopy(pytorch_model)
+
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (prefill_tokens, prefill_input_pos),
+              signature_name="prefill",
+              num_valid_inputs=1,
+          )
+      )
+
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              copied_model,
+              (decode_token, decode_input_pos),
+              signature_name="decode",
+              num_valid_inputs=1,
+          )
+      )
 
   def test_gemma(self):
     self.skipTest("b/338288901")
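The comment deleted in this hunk carries the reasoning behind copied_model: the PyTorch model keeps its KV cache as persistent internal state shared between prefill and decode, while the two TFLite signatures currently cannot share a cache, so the decode comparison must run against a copy whose cache prefill never touched. A minimal sketch of why the deep copy matters; the toy cache class below is ours, for illustration only:

import copy


class ToyKVModel:
  """Hypothetical stand-in for a model with a persistent KV cache."""

  def __init__(self):
    self.kv_cache = []

  def prefill(self, tokens):
    self.kv_cache.extend(tokens)  # mutates persistent state

  def decode(self, token):
    return len(self.kv_cache) + token  # output depends on the cache


model = ToyKVModel()
fresh = copy.deepcopy(model)  # snapshot taken before prefill runs

model.prefill([1, 2, 3])
print(model.decode(0))  # 3: sees the cache that prefill filled
print(fresh.decode(0))  # 0: like a TFLite decode signature with its own, untouched cache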
@@ -161,17 +165,20 @@ class TestModelConversion(unittest.TestCase):
 
     edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
 
-    # TODO(talumbau, haoliang): debug numerical diff.
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-2,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      # TODO(talumbau, haoliang): debug numerical diff.
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-2,
+              rtol=1e-5,
+          )
+      )
 
   def test_phi2(self):
     self.skipTest("b/338288901")
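Note that test_gemma keeps a much looser absolute tolerance (atol=1e-2) than the other tests, because of the numerical diff flagged in the retained TODO. Assuming compare_tflite_torch applies the usual allclose criterion, two tensors agree when |x - y| <= atol + rtol * |y| elementwise, so the 1e-2 budget absorbs differences the 1e-5 tests would reject:

import torch

x = torch.tensor([1.000, 2.000])
y = torch.tensor([1.009, 2.000])  # first element is 9e-3 off

print(torch.allclose(x, y, atol=1e-2, rtol=1e-5))  # True: 9e-3 <= 1e-2 + 1e-5 * 1.009
print(torch.allclose(x, y, atol=1e-5, rtol=1e-5))  # False under the tighter tolerance used elsewhere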
@@ -185,16 +192,19 @@ class TestModelConversion(unittest.TestCase):
 
     edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
 
-    self.assertTrue(
-        model_coverage.compare_tflite_torch(
-            edge_model,
-            pytorch_model,
-            (tokens, input_pos),
-            num_valid_inputs=1,
-            atol=1e-5,
-            rtol=1e-5,
-        )
-    )
+    # TODO(b/338288901): re-enable test to check output tensors.
+    skip_output_check = True
+    if skip_output_check is False:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              pytorch_model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-5,
+              rtol=1e-5,
+          )
+      )
 
 
 if __name__ == "__main__":
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240605
+Version: 0.2.0.dev20240606
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -83,7 +83,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2w
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
 ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=r-_hSanSjLZ_YXFpZUb0Up94u5F8JHp70Vf2nlONPSg,11269
@@ -102,8 +102,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/METADATA,sha256=GJzwmKkM4T0H-vTvMyoxiD80WfppEpE_sd2Ip4aSbgM,1748
-ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.2.0.dev20240605.dist-info/RECORD,,
+ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/METADATA,sha256=2yFjQFvs93hoppwWWSJi-B9HC2n-h1s2rzjsyGXU1zI,1748
+ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/RECORD,,
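Each RECORD line has the form path,sha256=<digest>,<size>, where the digest is an unpadded urlsafe-base64 SHA-256 of the file, per the wheel spec; the changed entries above track the edited test file (6500 -> 6811 bytes) and the version bump in METADATA. A small sketch that recomputes an entry's hash field; the function name is ours:

import base64
import hashlib


def record_hash(path):
  """Return a file's hash field in wheel RECORD form: sha256=<urlsafe b64, no padding>."""
  with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).digest()
  return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


# Example: verify the updated test file against its RECORD entry.
# print(record_hash("ai_edge_torch/generative/test/test_model_conversion.py"))
# expected: sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ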