ai-edge-torch-nightly 0.3.0.dev20240829__py3-none-any.whl → 0.3.0.dev20240901__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +33 -4
- ai_edge_torch/generative/examples/gemma/gemma.py +33 -3
- ai_edge_torch/generative/examples/gemma/gemma2.py +41 -3
- ai_edge_torch/generative/examples/phi2/phi2.py +30 -3
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -2
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +86 -160
- ai_edge_torch/generative/test/test_model_conversion_large.py +139 -0
- ai_edge_torch/model.py +20 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/RECORD +15 -14
- {ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/experimental/gemma/gemma.py
CHANGED

@@ -159,15 +159,44 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def
-
-
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
   return config
 
 
 def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
   config = (
-
+      get_fake_model_config(**kwargs)
       if test_model
       else get_model_config_2b(**kwargs)
   )
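The `get_fake_model_config` helpers added in this release trade realism for speed: a two-layer model with a 128-entry vocabulary converts in seconds, so conversion tests no longer need real checkpoints or full-size weights. A minimal sketch of the intended usage, mirroring what quantize/example.py and the test files later in this diff now do (weights are randomly initialized but shape-compatible):

    # Sketch only: build a tiny Gemma from the fake config for a conversion test.
    from ai_edge_torch.generative.examples.gemma import gemma

    config = gemma.get_fake_model_config()  # kv_cache_max_len defaults to 128
    model = gemma.Gemma(config).eval()      # random weights; no checkpoint needed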
ai_edge_torch/generative/examples/gemma/gemma.py
CHANGED

@@ -147,9 +147,39 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-
-
-
+# TODO(b/363021962): Clean up this part to streamline fake model config generation.
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=8,
+      head_dim=256,
+      num_query_groups=1,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+  )
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      embedding_dim=2048,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+  )
   return config
 
 
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED

@@ -209,9 +209,47 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def
-
-
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=4,
+      head_dim=64,
+      num_query_groups=4,
+      rotary_percentage=1.0,
+      qkv_transpose_before_split=True,
+      logit_softcap=50.0,
+      sliding_window_size=64,
+      attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
+      * 13,
+  )
+
+  norm_config = cfg.NormalizationConfig(
+      type=cfg.NormalizationType.RMS_NORM,
+      epsilon=1e-6,
+      zero_centered=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+      pre_ff_norm_config=norm_config,
+      post_ff_norm_config=norm_config,
+  )
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      embedding_dim=128,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=False,
+      lm_head_use_bias=False,
+      enable_hlfb=True,
+      final_logit_softcap=30.0,
+  )
   return config
 
 
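One detail worth noting in the fake Gemma2 config: `attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING] * 13` expands to 26 alternating entries, matching the layer count of the real Gemma2 2B architecture, even though the fake model sets `num_layers=2` and so presumably consumes only the first two entries. A quick illustration of the expansion, using hypothetical stand-in values rather than the real enum members:

    # Sketch only: how the alternating per-layer attention-type list expands.
    GLOBAL, LOCAL_SLIDING = "global", "local_sliding"  # stand-ins for cfg.AttentionType members
    attn_types = [GLOBAL, LOCAL_SLIDING] * 13
    assert len(attn_types) == 26
    assert attn_types[:2] == [GLOBAL, LOCAL_SLIDING]  # the two layers the fake model builds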
ai_edge_torch/generative/examples/phi2/phi2.py
CHANGED

@@ -139,9 +139,36 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def
-
-
+def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+  attn_config = cfg.AttentionConfig(
+      num_heads=16,
+      head_dim=80,
+      num_query_groups=4,
+      rotary_percentage=0.4,
+      qkv_use_bias=True,
+      output_proj_use_bias=True,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.SEQUENTIAL,
+      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
+      intermediate_size=128,
+      use_bias=True,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=128,
+      num_layers=2,
+      max_seq_len=2 * kv_cache_max_len,
+      kv_cache_max_len=kv_cache_max_len,
+      embedding_dim=128,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      parallel_residual=True,
+      lm_head_use_bias=True,
+      enable_hlfb=True,
+  )
   return config
 
 
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py
CHANGED

@@ -137,11 +137,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   return config
 
 
-def
+def get_fake_model_config() -> cfg.ModelConfig:
   config = get_model_config()
   config.vocab_size = 128
   config.num_layers = 2
-  config.ff_config.intermediate_size =
+  config.ff_config.intermediate_size = 64
   return config
 
 
ai_edge_torch/generative/quantize/example.py
CHANGED

@@ -22,7 +22,7 @@ import torch
 
 def main():
   # Build a PyTorch model as usual
-  config = gemma.
+  config = gemma.get_fake_model_config()
   model = gemma.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED

@@ -16,6 +16,7 @@
 import copy
 
 import ai_edge_torch
+from ai_edge_torch import config as ai_edge_config
 from ai_edge_torch.generative.examples.gemma import gemma, gemma2
 from ai_edge_torch.generative.examples.phi2 import phi2
 from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA

@@ -25,11 +26,27 @@ import numpy as np
 import torch
 
 from absl.testing import absltest as googletest
+from tensorflow.lite.python import interpreter
 
 
 class TestModelConversion(googletest.TestCase):
   """Unit tests that check for model conversion and correctness."""
 
+  def setUp(self):
+    super().setUp()
+    # Builder function for an Interpreter that supports custom ops.
+    self._interpreter_builder = (
+        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
+            custom_op_registerers=["GenAIOpsRegisterer"],
+            model_content=tflite_model,
+            experimental_default_delegate_latest_features=True,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_toy_model_with_kv_cache(self):
     config = toy_model_with_kv_cache.get_model_config()
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
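The `_interpreter_builder` added in `setUp` is a factory of factories: calling it with serialized TFLite bytes returns a zero-argument callable that constructs a fresh `InterpreterWithCustomOps` each time it is invoked, which is the shape `TfLiteModel.set_interpreter_builder` expects (see model.py below). A minimal sketch of the same pattern, where `model_bytes` is a hypothetical stand-in for the output of `edge_model.tflite_model()`:

    # Sketch only: the builder-factory pattern used by these tests.
    from tensorflow.lite.python import interpreter

    def make_builder(model_bytes):
      # The outer call binds the flatbuffer; the returned thunk builds a
      # new interpreter with the GenAI custom ops registered on each call.
      return lambda: interpreter.InterpreterWithCustomOps(
          custom_op_registerers=["GenAIOpsRegisterer"],
          model_content=model_bytes,
          experimental_default_delegate_latest_features=True,
      )

    builder = make_builder(model_bytes)
    tfl_interpreter = builder()  # call again whenever a fresh interpreter is needed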
@@ -38,22 +55,27 @@ class TestModelConversion(googletest.TestCase):
     )
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-
-
-
-
-
-
-
-
-
-
-            rtol=1e-5,
-        )
-    )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (idx, input_pos),
+            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_toy_model_with_multi_batches(self):
+    self.skipTest("b/362842043")
     config = toy_model_with_kv_cache.get_model_config()
     config.batch_size = 2
     pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()

@@ -62,21 +84,25 @@ class TestModelConversion(googletest.TestCase):
     )
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-
-
-
-
-
-
-
-
-
-
-            rtol=1e-5,
-        )
-    )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (idx, input_pos),
+            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_toy_model_with_kv_cache_with_hlfb(self):
     config = toy_model_with_kv_cache.get_model_config()
     config.enable_hlfb = True

@@ -86,49 +112,27 @@ class TestModelConversion(googletest.TestCase):
     )
 
     edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-
-
-
-
-
-
-
-
-
-
-            rtol=1e-5,
-        )
-    )
-
-  def test_tiny_llama(self):
-    self.skipTest("b/338288901")
-    config = tiny_llama.get_fake_model_config_for_test()
-    pytorch_model = tiny_llama.TinyLLamma(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (idx, input_pos),
+            num_valid_inputs=1,
+            atol=1e-5,
+            rtol=1e-5,
+        )
+    )
 
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
   def test_tiny_llama_multisig(self):
-    config = tiny_llama.
+    config = tiny_llama.get_fake_model_config()
     pytorch_model = tiny_llama.TinyLLamma(config).eval()
 
     # prefill

@@ -149,22 +153,25 @@ class TestModelConversion(googletest.TestCase):
         .signature("decode", pytorch_model, (decode_token, decode_input_pos))
         .convert()
     )
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
 
-
-    skip_output_check = True
-    if not skip_output_check:
-      copied_model = copy.deepcopy(pytorch_model)
+    copied_model = copy.deepcopy(pytorch_model)
 
-
-
-
-
-
-
-
-
-
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (prefill_tokens, prefill_input_pos),
+            signature_name="prefill",
+            num_valid_inputs=1,
+        )
+    )
 
+    # TODO(b/362840003): figure why this decode output has big numerical diff.
+    skip_output_check = True
+    if not skip_output_check:
       self.assertTrue(
           model_coverage.compare_tflite_torch(
               edge_model,
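The multisignature test converts one PyTorch module under two signature names and then validates each independently; `compare_tflite_torch` selects a signature via its `signature_name` argument, and `TfLiteModel.__call__` (see model.py below) accepts the same keyword. A sketch of per-signature inference, assuming the conversion chain begins with `ai_edge_torch.signature(...)` as in the hunk above and that inputs are passed as arrays:

    # Sketch only: invoking one signature of a multi-signature edge model.
    edge_model = (
        ai_edge_torch.signature("prefill", pytorch_model, (prefill_tokens, prefill_input_pos))
        .signature("decode", pytorch_model, (decode_token, decode_input_pos))
        .convert()
    )
    prefill_out = edge_model(
        prefill_tokens.numpy(), prefill_input_pos.numpy(), signature_name="prefill"
    )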
@@ -175,87 +182,6 @@ class TestModelConversion(googletest.TestCase):
           )
       )
 
-  def test_gemma(self):
-    self.skipTest("b/338288901")
-    config = gemma.get_fake_model_config_2b_for_test()
-    model = gemma.Gemma(config)
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      # TODO(talumbau, haoliang): debug numerical diff.
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-2,
-              rtol=1e-5,
-          )
-      )
-
-  def test_gemma2(self):
-    self.skipTest("b/338288901")
-    config = gemma2.get_fake_model_config_2b_for_test()
-    model = gemma2.Gemma2(config)
-    model.eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      # TODO(talumbau, haoliang): debug numerical diff.
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-2,
-              rtol=1e-5,
-          )
-      )
-
-  def test_phi2(self):
-    self.skipTest("b/338288901")
-    config = phi2.get_fake_model_config_for_test()
-    pytorch_model = phi2.Phi2(config).eval()
-
-    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
-
-    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
-
-    # TODO: b/338288901 - re-enable test to check output tensors.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              pytorch_model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-5,
-              rtol=1e-5,
-          )
-      )
-
 
 if __name__ == "__main__":
   googletest.main()
ai_edge_torch/generative/test/test_model_conversion_large.py
ADDED

@@ -0,0 +1,139 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Testing model conversion for a few gen-ai models.
+import copy
+
+import ai_edge_torch
+from ai_edge_torch import config as ai_edge_config
+from ai_edge_torch.generative.examples.gemma import gemma, gemma2
+from ai_edge_torch.generative.examples.phi2 import phi2
+from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.testing import model_coverage
+import numpy as np
+import torch
+
+from absl.testing import absltest as googletest
+from tensorflow.lite.python import interpreter
+
+
+class TestModelConversion(googletest.TestCase):
+  """Unit tests that check for model conversion and correctness."""
+
+  def setUp(self):
+    super().setUp()
+    # Builder function for an Interpreter that supports custom ops.
+    self._interpreter_builder = (
+        lambda tflite_model: lambda: interpreter.InterpreterWithCustomOps(
+            custom_op_registerers=["GenAIOpsRegisterer"],
+            model_content=tflite_model,
+            experimental_default_delegate_latest_features=True,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma(self):
+    config = gemma.get_fake_model_config()
+    model = gemma.Gemma(config)
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            model,
+            (tokens, input_pos),
+            num_valid_inputs=1,
+            atol=1e-2,
+            rtol=1e-5,
+        )
+    )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_gemma2(self):
+    config = gemma2.get_fake_model_config()
+    model = gemma2.Gemma2(config)
+    model.eval()
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    # TODO(b/362840003): debug numerical diff.
+    skip_output_check = True
+    if not skip_output_check:
+      self.assertTrue(
+          model_coverage.compare_tflite_torch(
+              edge_model,
+              model,
+              (tokens, input_pos),
+              num_valid_inputs=1,
+              atol=1e-2,
+              rtol=1e-5,
+          )
+      )
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_phi2(self):
+    config = phi2.get_fake_model_config()
+    pytorch_model = phi2.Phi2(config).eval()
+
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, 10)
+
+    edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            (tokens, input_pos),
+            num_valid_inputs=1,
+            atol=1e-3,
+            rtol=1e-3,
+        )
+    )
+
+
+if __name__ == "__main__":
+  googletest.main()
ai_edge_torch/model.py
CHANGED
@@ -22,6 +22,7 @@ from __future__ import annotations
 
 import abc
 import re
+from typing import Callable
 
 import numpy.typing as npt
 import tensorflow as tf

@@ -64,6 +65,24 @@ class TfLiteModel(Model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
+    self._interpreter_builder = lambda: tf.lite.Interpreter(
+        model_content=self._tflite_model,
+        experimental_default_delegate_latest_features=True,
+    )
+
+  def tflite_model(self) -> bytes:
+    """Returns the wrapped tflite model."""
+    return self._tflite_model
+
+  def set_interpreter_builder(
+      self, builder: Callable[[], tf.lite.Interpreter]
+  ) -> None:
+    """Sets a custom interpreter builder.
+
+    Args:
+      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+    """
+    self._interpreter_builder = builder
 
   def __call__(
       self,

@@ -80,10 +99,7 @@ class TfLiteModel(Model):
       signature_name: The name of the signature to be used for inference. The
         default signature is used if not provided.
     """
-    interpreter = tf.lite.Interpreter(
-        model_content=self._tflite_model,
-        experimental_default_delegate_latest_features=True,
-    )
+    interpreter = self._interpreter_builder()
     interpreter.allocate_tensors()
 
     signature_list = interpreter.get_signature_list()
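Together, `tflite_model()` and `set_interpreter_builder()` let a caller swap the default `tf.lite.Interpreter` for any zero-argument factory, which is how the generative tests above register the GenAI custom ops. A minimal end-to-end sketch, assuming `torch_model` and `sample_args` are already defined:

    # Sketch only: route inference through a custom-ops interpreter.
    import ai_edge_torch
    from tensorflow.lite.python import interpreter

    edge_model = ai_edge_torch.convert(torch_model.eval(), sample_args)
    edge_model.set_interpreter_builder(
        lambda: interpreter.InterpreterWithCustomOps(
            custom_op_registerers=["GenAIOpsRegisterer"],
            model_content=edge_model.tflite_model(),
            experimental_default_delegate_latest_features=True,
        )
    )
    output = edge_model(*sample_args)  # __call__ now uses the custom builder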
ai_edge_torch/version.py
CHANGED

{ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240829
+Version: 0.3.0.dev20240901
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240829.dist-info → ai_edge_torch_nightly-0.3.0.dev20240901.dist-info}/RECORD
CHANGED

@@ -1,8 +1,8 @@
 ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
-ai_edge_torch/model.py,sha256=
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
+ai_edge_torch/version.py,sha256=llxFt4Jrb5-zJ8uUalWlL75tHHddPLBl0Nyk0B7ecZU,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160

@@ -42,7 +42,7 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
 ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=
+ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=EdElPCDLYxnNvkPMJkE3WKvESze1ehgShEk2NnbrXLg,7527
 ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
 ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231

@@ -52,11 +52,11 @@ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=z
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
 ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=pzD9dYUYg8E6fFACh-8B8G9NHFXOVEWBjf5aDeipU2s,7202
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=ypd6uBb4FgDpuWm_w8JNYBAf4eFxWbYccs8vCgBhi-I,9374
 ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
-ai_edge_torch/generative/examples/phi2/phi2.py,sha256=
+ai_edge_torch/generative/examples/phi2/phi2.py,sha256=91mWxEtKgDtUhCAewWNwH_UOOCzy6tPdf6LNRlxZhrc,6700
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479

@@ -82,7 +82,7 @@ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.p
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=JmwU1sniO37vnCFc8dklbd-0ofTZK0PaBv_Ksn1Vq6M,5930
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671

@@ -100,7 +100,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3Xqsgr
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=
+ai_edge_torch/generative/quantize/example.py,sha256=Bmc-WowIJIfDgt84CNw2LhyLRi7SFcw8BQEu4byTKJU,1523
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270

@@ -111,7 +111,8 @@ ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
 ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=wQLVjMnKHBCVCU_I-xAUZvlOFoDiwYwKQDvCZ2mjtOM,6193
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
 ai_edge_torch/generative/test/test_quantize.py,sha256=JEsk9SAkHK0SFm44K_quISc5yBBS6yvtBP1MDyFHdFw,5344
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727

@@ -161,8 +162,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240901.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240901.dist-info/METADATA,sha256=6d0enZRdCzGdIp1MfpBysngGttpwRTkPilaM5jXDN2g,1878
+ai_edge_torch_nightly-0.3.0.dev20240901.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240901.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240901.dist-info/RECORD,,