ai-edge-torch-nightly 0.2.0.dev20240605__py3-none-any.whl → 0.2.0.dev20240607__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/generative/test/test_model_conversion.py +90 -80
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240607.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240607.dist-info}/RECORD +6 -6
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240607.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240607.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240605.dist-info → ai_edge_torch_nightly-0.2.0.dev20240607.dist-info}/top_level.txt +0 -0
|
@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
|
|
|
33
33
|
"""Unit tests that check for model conversion and correctness."""
|
|
34
34
|
|
|
35
35
|
def test_toy_model_with_kv_cache(self):
|
|
36
|
-
self.skipTest("b/338288901")
|
|
37
36
|
config = toy_model_with_kv_cache.get_model_config()
|
|
38
37
|
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
39
38
|
idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
|
@@ -42,19 +41,21 @@ class TestModelConversion(unittest.TestCase):
|
|
|
42
41
|
|
|
43
42
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
44
43
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
44
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
45
|
+
skip_output_check = True
|
|
46
|
+
if skip_output_check is False:
|
|
47
|
+
self.assertTrue(
|
|
48
|
+
model_coverage.compare_tflite_torch(
|
|
49
|
+
edge_model,
|
|
50
|
+
pytorch_model,
|
|
51
|
+
(idx, input_pos),
|
|
52
|
+
num_valid_inputs=1,
|
|
53
|
+
atol=1e-5,
|
|
54
|
+
rtol=1e-5,
|
|
55
|
+
)
|
|
56
|
+
)
|
|
55
57
|
|
|
56
58
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
|
57
|
-
self.skipTest("b/338288901")
|
|
58
59
|
config = toy_model_with_kv_cache.get_model_config()
|
|
59
60
|
config.enable_hlfb = True
|
|
60
61
|
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
@@ -64,16 +65,19 @@ class TestModelConversion(unittest.TestCase):
|
|
|
64
65
|
|
|
65
66
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
66
67
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
68
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
69
|
+
skip_output_check = True
|
|
70
|
+
if skip_output_check is False:
|
|
71
|
+
self.assertTrue(
|
|
72
|
+
model_coverage.compare_tflite_torch(
|
|
73
|
+
edge_model,
|
|
74
|
+
pytorch_model,
|
|
75
|
+
(idx, input_pos),
|
|
76
|
+
num_valid_inputs=1,
|
|
77
|
+
atol=1e-5,
|
|
78
|
+
rtol=1e-5,
|
|
79
|
+
)
|
|
80
|
+
)
|
|
77
81
|
|
|
78
82
|
def test_tiny_llama(self):
|
|
79
83
|
self.skipTest("b/338288901")
|
|
@@ -87,19 +91,21 @@ class TestModelConversion(unittest.TestCase):
|
|
|
87
91
|
|
|
88
92
|
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
|
|
89
93
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
94
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
95
|
+
skip_output_check = True
|
|
96
|
+
if skip_output_check is False:
|
|
97
|
+
self.assertTrue(
|
|
98
|
+
model_coverage.compare_tflite_torch(
|
|
99
|
+
edge_model,
|
|
100
|
+
pytorch_model,
|
|
101
|
+
(tokens, input_pos),
|
|
102
|
+
num_valid_inputs=1,
|
|
103
|
+
atol=1e-5,
|
|
104
|
+
rtol=1e-5,
|
|
105
|
+
)
|
|
106
|
+
)
|
|
100
107
|
|
|
101
108
|
def test_tiny_llama_multisig(self):
|
|
102
|
-
self.skipTest("b/338288901")
|
|
103
109
|
config = tiny_llama.get_fake_model_config_for_test()
|
|
104
110
|
pytorch_model = tiny_llama.TinyLLamma(config)
|
|
105
111
|
|
|
@@ -122,32 +128,30 @@ class TestModelConversion(unittest.TestCase):
|
|
|
122
128
|
.convert()
|
|
123
129
|
)
|
|
124
130
|
|
|
125
|
-
#
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
)
|
|
150
|
-
)
|
|
131
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
132
|
+
skip_output_check = True
|
|
133
|
+
if skip_output_check is False:
|
|
134
|
+
copied_model = copy.deepcopy(pytorch_model)
|
|
135
|
+
|
|
136
|
+
self.assertTrue(
|
|
137
|
+
model_coverage.compare_tflite_torch(
|
|
138
|
+
edge_model,
|
|
139
|
+
pytorch_model,
|
|
140
|
+
(prefill_tokens, prefill_input_pos),
|
|
141
|
+
signature_name="prefill",
|
|
142
|
+
num_valid_inputs=1,
|
|
143
|
+
)
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
self.assertTrue(
|
|
147
|
+
model_coverage.compare_tflite_torch(
|
|
148
|
+
edge_model,
|
|
149
|
+
copied_model,
|
|
150
|
+
(decode_token, decode_input_pos),
|
|
151
|
+
signature_name="decode",
|
|
152
|
+
num_valid_inputs=1,
|
|
153
|
+
)
|
|
154
|
+
)
|
|
151
155
|
|
|
152
156
|
def test_gemma(self):
|
|
153
157
|
self.skipTest("b/338288901")
|
|
@@ -161,17 +165,20 @@ class TestModelConversion(unittest.TestCase):
|
|
|
161
165
|
|
|
162
166
|
edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
|
|
163
167
|
|
|
164
|
-
# TODO(
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
168
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
169
|
+
skip_output_check = True
|
|
170
|
+
if skip_output_check is False:
|
|
171
|
+
# TODO(talumbau, haoliang): debug numerical diff.
|
|
172
|
+
self.assertTrue(
|
|
173
|
+
model_coverage.compare_tflite_torch(
|
|
174
|
+
edge_model,
|
|
175
|
+
model,
|
|
176
|
+
(tokens, input_pos),
|
|
177
|
+
num_valid_inputs=1,
|
|
178
|
+
atol=1e-2,
|
|
179
|
+
rtol=1e-5,
|
|
180
|
+
)
|
|
181
|
+
)
|
|
175
182
|
|
|
176
183
|
def test_phi2(self):
|
|
177
184
|
self.skipTest("b/338288901")
|
|
@@ -185,16 +192,19 @@ class TestModelConversion(unittest.TestCase):
|
|
|
185
192
|
|
|
186
193
|
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
|
|
187
194
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
195
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
196
|
+
skip_output_check = True
|
|
197
|
+
if skip_output_check is False:
|
|
198
|
+
self.assertTrue(
|
|
199
|
+
model_coverage.compare_tflite_torch(
|
|
200
|
+
edge_model,
|
|
201
|
+
pytorch_model,
|
|
202
|
+
(tokens, input_pos),
|
|
203
|
+
num_valid_inputs=1,
|
|
204
|
+
atol=1e-5,
|
|
205
|
+
rtol=1e-5,
|
|
206
|
+
)
|
|
207
|
+
)
|
|
198
208
|
|
|
199
209
|
|
|
200
210
|
if __name__ == "__main__":
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.dev20240605
|
|
3
|
+
Version: 0.2.0.dev20240607
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -83,7 +83,7 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2w
|
|
|
83
83
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
|
|
84
84
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
85
85
|
ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
|
|
86
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
|
86
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
|
|
87
87
|
ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
|
|
88
88
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
89
89
|
ai_edge_torch/generative/utilities/loader.py,sha256=r-_hSanSjLZ_YXFpZUb0Up94u5F8JHp70Vf2nlONPSg,11269
|
|
@@ -102,8 +102,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
|
|
|
102
102
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
103
103
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
104
104
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
|
|
105
|
-
ai_edge_torch_nightly-0.2.0.
|
|
106
|
-
ai_edge_torch_nightly-0.2.0.
|
|
107
|
-
ai_edge_torch_nightly-0.2.0.
|
|
108
|
-
ai_edge_torch_nightly-0.2.0.
|
|
109
|
-
ai_edge_torch_nightly-0.2.0.
|
|
105
|
+
ai_edge_torch_nightly-0.2.0.dev20240607.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
106
|
+
ai_edge_torch_nightly-0.2.0.dev20240607.dist-info/METADATA,sha256=wXfrp5POJLq610NgMNlMi-WaC2AHCc0u9d4z2t6KiMg,1748
|
|
107
|
+
ai_edge_torch_nightly-0.2.0.dev20240607.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
108
|
+
ai_edge_torch_nightly-0.2.0.dev20240607.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
109
|
+
ai_edge_torch_nightly-0.2.0.dev20240607.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|