ai-edge-torch-nightly 0.3.0.dev20240919__py3-none-any.whl → 0.3.0.dev20240921__py3-none-any.whl
- ai_edge_torch/_convert/test/test_convert.py +7 -3
- ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} +9 -7
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +3 -36
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -26
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +55 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +55 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +142 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/openelm/verify.py +6 -4
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/phi/verify.py +14 -4
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/smollm/verify.py +5 -4
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/tiny_llama/verify.py +6 -5
- ai_edge_torch/generative/layers/feed_forward.py +0 -1
- ai_edge_torch/generative/quantize/example.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +5 -5
- ai_edge_torch/generative/utilities/verifier.py +77 -26
- ai_edge_torch/model.py +7 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD +28 -25
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py
CHANGED
@@ -23,12 +23,12 @@ from ai_edge_torch import config
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision

 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import


 @dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

-    interpreter =
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter =
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
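The tests above now build the interpreter from ai_edge_litert rather than through tf.lite. A minimal sketch of that pattern outside the test harness, assuming a hypothetical toy module and sample input; only the Interpreter and signature-runner calls (and the `_tflite_model` access) are taken from the updated tests.

# Sketch, with a hypothetical model; not part of the package diff itself.
import torch
import ai_edge_torch
from ai_edge_litert import interpreter as tfl_interpreter


class AddOne(torch.nn.Module):  # hypothetical model for illustration
  def forward(self, x):
    return x + 1.0


edge_model = ai_edge_torch.convert(AddOne().eval(), (torch.zeros(1, 4),))

# Build the interpreter from ai_edge_litert instead of tf.lite.Interpreter,
# as the rewritten tests now do.
interp = tfl_interpreter.Interpreter(model_content=edge_model._tflite_model)
runner = interp.get_signature_runner("serving_default")
print(runner.get_input_details())
print(runner.get_output_details())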
ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py}
RENAMED
@@ -13,14 +13,14 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of converting a
+"""Example of converting a Gemma1 model to multi-signature tflite model."""

 import os
 import pathlib

 from absl import app
 from absl import flags
-from ai_edge_torch.generative.examples.gemma import
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.utilities import converter

 _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -51,12 +51,14 @@ _QUANTIZE = flags.DEFINE_bool(


 def main(_):
-  pytorch_model =
+  pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
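With this change the converter scripts no longer take a full output file path; the file name is composed from the quantization flag and the sequence-length flags. A small sketch of the resulting path, assuming the new defaults shown above (prefill_seq_len=1024, kv_cache_max_len=1280) with quantization enabled.

# Sketch of the new output naming; values are the flag defaults from the diff.
import os

quantize = True
prefill_seq_len = 1024
kv_cache_max_len = 1280
tflite_path = '/tmp/'

quant_suffix = 'q8' if quantize else 'f32'
output_filename = f'gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
print(os.path.join(tflite_path, output_filename))  # /tmp/gemma_q8_seq1024_ekv1280.tflite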
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py}
RENAMED
@@ -13,10 +13,7 @@
 # limitations under the License.
 # ==============================================================================

-"""Example of building a
-
-import os
-import pathlib
+"""Example of building a Gemma1 model."""

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
@@ -24,7 +21,6 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
 import torch
 from torch import nn

@@ -32,13 +28,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
     ff_up_proj="model.layers.{}.mlp.up_proj",
     ff_down_proj="model.layers.{}.mlp.down_proj",
     ff_gate_proj="model.layers.{}.mlp.gate_proj",
-
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
+    attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
     attn_output_proj="model.layers.{}.self_attn.o_proj",
     pre_attn_norm="model.layers.{}.input_layernorm",
     post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="
+    embedding="embedder",
     final_norm="model.norm",
     lm_head=None,
 )
@@ -192,30 +186,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-  """Instantiates and runs a Gemma 2B model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
-
-  kv_cache_max_len = 1024
-  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  output = model.forward(tokens, input_pos, kv)
-  print("comparing with goldens..")
-  assert torch.allclose(
-      gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-  )
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(
-      pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
-  )
-  define_and_run_2b(input_checkpoint_path)
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -267,29 +267,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
   loader.load(model, strict=False)
   model.eval()
   return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-  """Instantiates and runs a Gemma2 2B model."""
-
-  current_dir = pathlib.Path(__file__).parent.resolve()
-  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
-  print("Running GEMMA 2")
-  kv_cache_max_len = 1024
-  model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-  toks = torch.from_numpy(
-      np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
-  )
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-  tokens[0, :9] = toks
-  input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(model.config)
-  out = model.forward(tokens, input_pos, kv)
-  out_final = out["logits"][0, 8, :]
-  assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
-
-
-if __name__ == "__main__":
-  torch.set_printoptions(sci_mode=True)
-  path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
-  define_and_run_2b(path)
ai_edge_torch/generative/examples/gemma/verify_gemma1.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma1 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma1
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = gemma1.build_2b_model(checkpoint)
+
+  verify_util.verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant="2b",
+      reauthored_model=reauthored_model,
+      weight_filename="gemma-2b-it.ckpt",
+      generate_prompts=_PROMPTS.value,
+      forward_input_ids=[[1, 2, 3, 4]],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/gemma/verify_gemma2.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+    "prompts",
+    "What is the meaning of life?",
+    "The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+  checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
+
+  verifier.log_msg("Building the reauthored model from", checkpoint)
+  reauthored_model = gemma2.build_2b_model(checkpoint)
+
+  verify_util.verify_reauthored_gemma_model(
+      checkpoint=checkpoint,
+      variant="2b-v2",
+      reauthored_model=reauthored_model,
+      generate_prompts=_PROMPTS.value,
+      forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+      max_new_tokens=_MAX_NEW_TOKENS.value,
+      atol=1e-04,
+  )
+
+
+if __name__ == "__main__":
+  app.run(main)
ai_edge_torch/generative/examples/gemma/verify_util.py
ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions to verify the reauthored Gemma model."""
+
+import dataclasses
+import os
+from typing import List, Tuple
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import verifier
+from gemma import config as gemma_config
+from gemma import model as gemma_model
+import torch
+
+
+@dataclasses.dataclass
+class _Output:
+  logits: torch.Tensor
+
+
+class GemmaWrapper(verifier.ModelWrapper):
+  """Gemma model wrapper for verification.
+
+  Verifier calls model.forward() with maxium sequence length (1024) expecting
+  the output has 'logits' field while Gemma gets the input tokens with the
+  actual length and returns logits in a tuple.
+
+  Verifier runs tokenizer before model.generate() while Gemma runs the tokenizer
+  inside model.generate().
+  """
+
+  def __init__(self, model: torch.nn.Module, max_new_tokens: int):
+    super().__init__(model)
+    self.max_new_tokens = max_new_tokens
+
+  def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
+    for i in range(tokens.shape[1]):
+      if tokens[0, i] == 0:
+        return i
+    return tokens.shape[1]
+
+  def _get_kv_caches(
+      self, max_seq_len: int
+  ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+    config = self.model.config
+    cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+    cache = torch.zeros(cache_size)
+    return [
+        (cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+    ]
+
+  def forward(self, tokens: torch.Tensor) -> _Output:
+    """Forwards the model after reducing input tokens to the actual length."""
+    actual_input_len = self._get_actual_input_len(tokens)
+    input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+    mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+    _, logits = self.model.forward(
+        input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+        input_positions=input_pos,
+        kv_write_indices=None,
+        kv_caches=self._get_kv_caches(tokens.shape[1]),
+        mask=mask_cache.index_select(2, input_pos),
+        output_positions=input_pos,
+        temperatures=None,
+        top_ps=torch.tensor([1.0], dtype=torch.float),
+        top_ks=torch.tensor([1], dtype=torch.long),
+    )
+    return _Output(logits.float())
+
+  def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+    """Generates the response after decoding the tokens into a string."""
+    prompts = self.model.tokenizer.decode(tokens[0].tolist())
+    response = self.model.generate(
+        prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+    )
+    return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+class TokenizerWrapper(torch.nn.Module):
+  """Tokenizer wrapper for verification.
+
+  Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
+  tokenizer expects tokens in a list.
+  """
+
+  def __init__(self, tokenizer: torch.nn.Module):
+    super().__init__()
+    self.tokenizer = tokenizer
+
+  def encode(self, text: str, **_) -> torch.Tensor:
+    """Adds one more dimension to the output of the tokenizer."""
+    return torch.tensor([self.tokenizer.encode(text)])
+
+  def decode(self, tokens: torch.Tensor) -> str:
+    """Decodes the token sequence after converting to a list."""
+    return self.tokenizer.decode(tokens.tolist())
+
+
+def verify_reauthored_gemma_model(
+    checkpoint: str,
+    variant: str,
+    reauthored_model: torch.nn.Module,
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]],
+    weight_filename: str = "model.ckpt",
+    tokenizer_filename: str = "tokenizer.model",
+    max_new_tokens: int = 20,
+    rtol: float = 1e-05,
+    atol: float = 1e-05,
+):
+  """Verifies the reauthored Gemma model against the original model."""
+  config = gemma_config.get_model_config(variant)
+  config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+  # Use float32 to be compatible with the reauthored model.
+  config.dtype = torch.float32
+
+  verifier.log_msg("Loading the original model from", checkpoint)
+  original_model = gemma_model.GemmaForCausalLM(config).eval()
+  original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+  verifier.verify_reauthored_model(
+      original_model=GemmaWrapper(original_model, max_new_tokens),
+      reauthored_model=reauthored_model,
+      tokenizer=TokenizerWrapper(original_model.tokenizer),
+      generate_prompts=generate_prompts,
+      forward_input_ids=forward_input_ids,
+      rtol=rtol,
+      atol=atol,
+  )
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/openelm/verify.py
CHANGED
@@ -33,8 +33,10 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "apple/OpenELM-3B"
   verifier.log_msg("Loading the original model from", checkpoint)
-
-
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
   )

   # Locate the cached dir.
@@ -50,10 +52,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

   verifier.verify_reauthored_model(
-      original_model=
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-
+      generate_prompts=_PROMPTS.value,
   )

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/phi/verify.py
CHANGED
@@ -24,15 +24,25 @@ import transformers

 _PROMPTS = flags.DEFINE_multi_string(
     "prompts",
-    "
+    "Instruct: Write an email about the weather Output:",
     "The input prompts to generate answers.",
 )

+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+    "max_new_tokens",
+    30,
+    "The maximum size of the generated tokens.",
+)

 def main(_):
   checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
   verifier.log_msg("Loading the original model from", checkpoint)
-
+  generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+  generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+      hf_generation_config=generation_config,
+  )

   verifier.log_msg("Building the reauthored model from", checkpoint)
   reauthored_model = phi2.build_model(checkpoint)
@@ -41,10 +51,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-
+      generate_prompts=_PROMPTS.value,
       atol=1e-03,
   )

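The Phi-2 verify script now caps generation length by attaching a transformers.GenerationConfig to the ModelWrapper. A sketch of how that config flows through ModelWrapper.generate(); it assumes the Hugging Face hub id "microsoft/phi-2" in place of the Kaggle download used by the script above.

# Sketch, with an assumed HF hub id; the wrapper API matches the verifier diff.
import transformers
from ai_edge_torch.generative.utilities import verifier

checkpoint = "microsoft/phi-2"  # assumption: HF equivalent of the Kaggle checkpoint
generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
generation_config.max_new_tokens = 30

wrapper_model = verifier.ModelWrapper(
    model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
    hf_generation_config=generation_config,
)

tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
input_ids = tokenizer(
    "Instruct: Write an email about the weather Output:", return_tensors="pt"
).input_ids
# ModelWrapper.generate() forwards the generation config, so at most 30 new
# tokens are produced here.
output_ids = wrapper_model.generate(input_ids)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))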
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/smollm/verify.py
CHANGED
@@ -33,8 +33,9 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "HuggingFaceTB/SmolLM-135M"
   verifier.log_msg("Loading the original model from", checkpoint)
-
-
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+  )
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -47,10 +48,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-
+      generate_prompts=_PROMPTS.value,
       atol=1e-04,
   )

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
 )
 _TFLITE_PATH = flags.DEFINE_string(
     'tflite_path',
-    '/tmp/
+    '/tmp/',
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
-
+    1024,
     'The maximum size of prefill input tensor.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
-
+    1280,
     'The maximum size of KV cache buffer, including both prefill and decode.',
 )
 _QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=_TFLITE_PATH.value,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       quantize=_QUANTIZE.value,
   )
ai_edge_torch/generative/examples/tiny_llama/verify.py
CHANGED
@@ -33,10 +33,11 @@ _PROMPTS = flags.DEFINE_multi_string(
 def main(_):
   checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
   verifier.log_msg("Loading the original model from", checkpoint)
-
-
+  wrapper_model = verifier.ModelWrapper(
+      model=transformers.AutoModelForCausalLM.from_pretrained(
+          checkpoint, trust_remote_code=True
+      ),
   )
-
   # Locate the cached dir.
   cached_config_file = transformers.utils.cached_file(
       checkpoint, transformers.utils.CONFIG_NAME
@@ -49,10 +50,10 @@ def main(_):
   tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

   verifier.verify_reauthored_model(
-      original_model=
+      original_model=wrapper_model,
       reauthored_model=reauthored_model,
       tokenizer=tokenizer,
-
+      generate_prompts=_PROMPTS.value,
       atol=1e-04,
   )

ai_edge_torch/generative/quantize/example.py
CHANGED
@@ -14,7 +14,7 @@
 # ==============================================================================

 import ai_edge_torch
-from ai_edge_torch.generative.examples.gemma import
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.quantize import quant_recipes
 import numpy as np
 import torch
@@ -22,8 +22,8 @@ import torch

 def main():
   # Build a PyTorch model as usual
-  config =
-  model =
+  config = gemma1.get_fake_model_config()
+  model = gemma1.Gemma(config)
   idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
   tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
   tokens[0, :4] = idx
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -17,7 +17,7 @@

 import ai_edge_torch
 from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import
+from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.openelm import openelm
 from ai_edge_torch.generative.examples.phi import phi2
@@ -28,7 +28,7 @@ import numpy as np
 import torch

 from absl.testing import absltest as googletest
-from
+from ai_edge_litert import interpreter
 

 class TestModelConversion(googletest.TestCase):
@@ -82,9 +82,9 @@ class TestModelConversion(googletest.TestCase):
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
-  def
-  config =
-  pytorch_model =
+  def test_gemma1(self):
+    config = gemma1.get_fake_model_config()
+    pytorch_model = gemma1.Gemma(config).eval()
     self._test_model(
         config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
     )
ai_edge_torch/generative/utilities/verifier.py
CHANGED
@@ -16,17 +16,65 @@
 """Common utility functions to verify the reauthored models."""

 import datetime
-from typing import List
+from typing import List, Optional, Union

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import numpy as np
 import torch
+import transformers


 def log_msg(*args):
   print("[%s]" % datetime.datetime.now(), *args)


+class ModelWrapper(torch.nn.Module):
+  """A wrapper for the model to be verified, this could be a HuggingFace model
+
+  or a regular PyTorch model.
+  """
+
+  def __init__(
+      self,
+      model: torch.nn.Module,
+      model_format: str = "huggingface",
+      hf_generation_config: Optional[transformers.GenerationConfig] = None,
+  ):
+    """Initializes the wrapper.
+
+    Args:
+      model (torch.nn.Module): The original model. This could be a model built
+        from HuggingFace transformers, or a regular PyTorch model.
+      model_format (str): The format of the model. It should be either
+        "huggingface" or "pytorch".
+      hf_generation_config (transformers.GenerationConfig): The HuggingFace
+        generation config. This config will only be used if the underlying model
+        is built from HuggingFace transformers.
+    """
+    super().__init__()
+    self.model = model
+    self.model_format = model_format
+    self.hf_generation_config = hf_generation_config
+
+  def generate(
+      self, inputs: torch.Tensor
+  ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
+    if self.model_format == "huggingface":
+      return self.model.generate(
+          inputs=inputs, generation_config=self.hf_generation_config
+      )
+    else:
+      raise NotImplementedError(
+          "generate() is not implemented for model format: %s"
+          % self.model_format
+      )
+
+  def forward(
+      self,
+      inputs: torch.Tensor,
+  ):
+    return self.model.forward(inputs)
+
+
 def forward(
     model: torch.nn.Module,
     tokens: torch.Tensor,
@@ -75,9 +123,9 @@ def generate(


 def verify_with_input_ids(
-    original_model:
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
-    input_ids:
+    input_ids: List[int],
     kv_cache_max_len: int = 1024,
     rtol: float = 1e-05,
     atol: float = 1e-05,
@@ -87,10 +135,10 @@ def verify_with_input_ids(
   It compares only one outputs from the original and the reauthored model.

   Args:
-    original_model (
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
-    input_ids (
+    input_ids (List[int]): The input token IDs to forward with.
     kv_cache_max_len (int): The maximum sequence length of the KV cache.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
@@ -99,18 +147,17 @@ def verify_with_input_ids(
     True if the model reauthored generates the same output of the original.
   """
   tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-
-  tokens[0, :input_ids_len] = input_ids
+  tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()

   log_msg("Forwarding the original model...")
   outputs_original = original_model.forward(tokens)
-  logits_original = outputs_original.logits[0,
+  logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
   log_msg("logits_original: ", logits_original)

   log_msg("Forwarding the reauthored model...")
   kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
   outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-  logits_reauthored = outputs_reauthored[0][0,
+  logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
   log_msg("logits_reauthored:", logits_reauthored)

   return torch.allclose(
@@ -119,7 +166,7 @@ def verify_with_input_ids(


 def verify_model_with_prompts(
-    original_model:
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
     tokenizer: torch.nn.Module,
     prompts: str,
@@ -130,7 +177,7 @@ def verify_model_with_prompts(
   original and the reauthored model.

   Args:
-    original_model (
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
     tokenizer (torch.nn.Module): The tokenizer.
@@ -156,10 +203,11 @@ def verify_model_with_prompts(


 def verify_reauthored_model(
-    original_model:
+    original_model: ModelWrapper,
     reauthored_model: torch.nn.Module,
     tokenizer: torch.nn.Module,
-
+    generate_prompts: List[str],
+    forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
     rtol: float = 1e-05,
     atol: float = 1e-05,
 ):
@@ -174,26 +222,29 @@ def verify_reauthored_model(
   It prints out "PASS" or "FAILED" to the console.

   Args:
-    original_model (
+    original_model (ModelWrapper): The original model.
     reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
       Generative API.
     tokenizer (torch.nn.Module): The tokenizer.
-
+    generate_prompts (List[str]): List of the input prompts to generate answers.
+    forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
+      forward with.
     rtol (float): The relative tolerance for the comparison.
     atol (float): The absolute tolerance for the comparison.
   """
-
-
-
-
-
-
-
+  for input_ids in forward_input_ids:
+    log_msg("Verifying the reauthored model with input IDs:", input_ids)
+    if verify_with_input_ids(
+        original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+    ):
+      log_msg("PASS")
+    else:
+      log_msg("FAILED")

-  for
-  log_msg("Verifying the reauthored model with prompts:",
+  for prompts in generate_prompts:
+    log_msg("Verifying the reauthored model with prompts:", prompts)
   if verify_model_with_prompts(
-      original_model, reauthored_model, tokenizer,
+      original_model, reauthored_model, tokenizer, prompts
   ):
     log_msg("PASS")
   else:
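ModelWrapper.generate() only handles the "huggingface" model format; for any other original model it raises NotImplementedError, which is why verify_util.GemmaWrapper subclasses it and overrides forward()/generate(). A skeleton of such a subclass, with the adaptation details left as assumptions.

# Minimal sketch, not part of the library: wrapping a plain PyTorch original
# model. The method bodies are placeholders; see GemmaWrapper in verify_util.py
# for a complete, working example.
import torch
from ai_edge_torch.generative.utilities import verifier


class PlainPyTorchWrapper(verifier.ModelWrapper):
  """Adapts a plain PyTorch model to what verify_reauthored_model() expects."""

  def __init__(self, model: torch.nn.Module):
    super().__init__(model, model_format="pytorch")

  def forward(self, tokens: torch.Tensor):
    # verify_with_input_ids() reads `.logits` of shape (batch, seq, vocab)
    # from this return value, so adapt the wrapped model's output accordingly.
    return self.model(tokens)

  def generate(self, tokens: torch.Tensor) -> torch.Tensor:
    # verify_model_with_prompts() expects prompt + generated token IDs back as
    # a (1, n) tensor; implement the original model's decoding here.
    raise NotImplementedError("decoding for the original model goes here")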
ai_edge_torch/model.py
CHANGED
@@ -27,6 +27,8 @@ from typing import Callable
 import numpy.typing as npt
 import tensorflow as tf

+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
+
 DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY


@@ -65,7 +67,7 @@ class TfLiteModel(Model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
-    self._interpreter_builder = lambda:
+    self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
         model_content=self._tflite_model,
         experimental_default_delegate_latest_features=True,
     )
@@ -75,12 +77,13 @@ class TfLiteModel(Model):
     return self._tflite_model

   def set_interpreter_builder(
-      self, builder: Callable[[],
+      self, builder: Callable[[], tfl_interpreter.Interpreter]
   ) -> None:
     """Sets a custom interpreter builder.

     Args:
-      builder: A function that returns a `
+      builder: A function that returns a `tfl_interpreter.Interpreter` or its
+        subclass.
     """
     self._interpreter_builder = builder

@@ -166,7 +169,7 @@ class TfLiteModel(Model):

     # Check if this is indeed a tflite model:
     try:
-      interpreter =
+      interpreter = tfl_interpreter.Interpreter(model_content=model_content)
       interpreter.get_signature_list()
     except:
       return None
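TfLiteModel.set_interpreter_builder() is now typed against the ai_edge_litert interpreter. A sketch of supplying a custom builder, assuming `edge_model` is the result of an earlier ai_edge_torch.convert() call; the Interpreter construction matches the default builder above, minus the experimental delegate flag.

# Sketch: installing a custom interpreter builder on a converted model.
from ai_edge_litert import interpreter as tfl_interpreter


def build_interpreter():
  # Same constructor call as the default builder in the diff, without the
  # experimental_default_delegate_latest_features option.
  return tfl_interpreter.Interpreter(model_content=edge_model._tflite_model)


edge_model.set_interpreter_builder(build_interpreter)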
ai_edge_torch/version.py
CHANGED
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20240921
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD
RENAMED
@@ -2,8 +2,8 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
-ai_edge_torch/model.py,sha256=
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
+ai_edge_torch/version.py,sha256=t9zajdsiowClI2fG0RkKVonPF-SUx9UBuUDOEZFU9y4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -25,7 +25,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=
+ai_edge_torch/_convert/test/test_convert.py,sha256=40QRxQFNeSRr4dLXJkzG-wKUlvJtsfv62cdvRrmBv5w,15097
 ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -39,22 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=KsM6VlzluTqbodG24IFr3biPxBrLay0z0gmnG0bcU2U,9277
+ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=B14IR4mIw6qBVUbiIRdfdUzHMCIJCJ0RFPsYOxA46qc,1776
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax1m4WTAnC3gMxBpcvGA7-xTO1Iuw,1802
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
-ai_edge_torch/generative/examples/openelm/verify.py,sha256=
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=
+ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -78,16 +81,16 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
-ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LUChL5tA7FHL_DlTg5QKvGInmH9AwVVw9a-omcndiz8,2095
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
 ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
-ai_edge_torch/generative/layers/feed_forward.py,sha256=
+ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
 ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
 ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
@@ -98,7 +101,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=
+ai_edge_torch/generative/quantize/example.py,sha256=tlACaRsz6lqOxakzpXVFJZYfFKOiFqetcYVJqWVRdPE,1542
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
 ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -107,8 +110,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -116,7 +119,7 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=vU9KgmFS7I9jNS_3H2SWROx-rbNqtMKgQC2MRhdqQ4g,8803
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -163,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/METADATA,sha256=SWy7BhOQDe0_SBF17deNndzt1bEYy7iXUxy0KznIPYM,1859
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,