ai-edge-torch-nightly 0.3.0.dev20240919__py3-none-any.whl → 0.3.0.dev20240921__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- ai_edge_torch/_convert/test/test_convert.py +7 -3
- ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py} +9 -7
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +3 -36
- ai_edge_torch/generative/examples/gemma/gemma2.py +0 -26
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +55 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +55 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +142 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/openelm/verify.py +6 -4
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/phi/verify.py +14 -4
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/smollm/verify.py +5 -4
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +6 -4
- ai_edge_torch/generative/examples/tiny_llama/verify.py +6 -5
- ai_edge_torch/generative/layers/feed_forward.py +0 -1
- ai_edge_torch/generative/quantize/example.py +3 -3
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +5 -5
- ai_edge_torch/generative/utilities/verifier.py +77 -26
- ai_edge_torch/model.py +7 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD +28 -25
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/top_level.txt +0 -0
ai_edge_torch/_convert/test/test_convert.py
CHANGED
@@ -23,12 +23,12 @@ from ai_edge_torch import config
from ai_edge_torch._convert import conversion_utils
from ai_edge_torch.testing import model_coverage
import numpy as np
-import tensorflow as tf
import torch
from torch import nn
import torchvision

from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


@dataclasses.dataclass
@@ -466,7 +466,9 @@ class TestConvert(googletest.TestCase):
np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

-interpreter =
+interpreter = tfl_interpreter.Interpreter(
+model_content=edge_model._tflite_model
+)
runner = interpreter.get_signature_runner("serving_default")
output_details = runner.get_output_details()
self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ class TestConvert(googletest.TestCase):
def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
model.eval()
edge_model = ai_edge_torch.convert(model, args, kwargs)
-interpreter =
+interpreter = tfl_interpreter.Interpreter(
+model_content=edge_model._tflite_model
+)
runner = interpreter.get_signature_runner("serving_default")
input_details = runner.get_input_details()
self.assertEqual(input_details.keys(), flat_inputs.keys())
ai_edge_torch/generative/examples/gemma/{convert_to_tflite.py → convert_gemma1_to_tflite.py}
RENAMED
@@ -13,14 +13,14 @@
# limitations under the License.
# ==============================================================================

-"""Example of converting a
+"""Example of converting a Gemma1 model to multi-signature tflite model."""

import os
import pathlib

from absl import app
from absl import flags
-from ai_edge_torch.generative.examples.gemma import
+from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter

_CHECKPOINT_PATH = flags.DEFINE_string(
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
-'/tmp/
+'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
-
+1024,
'The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
-
+1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
@@ -51,12 +51,14 @@ _QUANTIZE = flags.DEFINE_bool(


def main(_):
-pytorch_model =
+pytorch_model = gemma1.build_2b_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
+quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+output_filename = f'gemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
-tflite_path=_TFLITE_PATH.value,
+tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
)
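Note on the converter change above: --tflite_path is now treated as an output directory, and the script derives the file name from the quantization and sequence-length flags. A minimal sketch of the resulting default path, using the new flag defaults shown above (the concrete path is only an illustration):

import os

# Values mirror the new flag defaults in convert_gemma1_to_tflite.py.
tflite_dir = '/tmp/'
prefill_seq_len = 1024
kv_cache_max_len = 1280
quantize = True

quant_suffix = 'q8' if quantize else 'f32'
output_filename = f'gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
print(os.path.join(tflite_dir, output_filename))  # -> /tmp/gemma_q8_seq1024_ekv1280.tflite

The same pattern is applied in each of the other convert_to_tflite.py scripts below, each with its own model prefix.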
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
-'/tmp/
+'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
-
+1024,
'The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
-
+1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
pytorch_model = gemma2.build_2b_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
+quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+output_filename = f'gemma2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
-tflite_path=_TFLITE_PATH.value,
+tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
)
ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py}
RENAMED
@@ -13,10 +13,7 @@
# limitations under the License.
# ==============================================================================

-"""Example of building a
-
-import os
-import pathlib
+"""Example of building a Gemma1 model."""

from ai_edge_torch.generative.layers import attention
from ai_edge_torch.generative.layers import builder
@@ -24,7 +21,6 @@ from ai_edge_torch.generative.layers import kv_cache as kv_utils
import ai_edge_torch.generative.layers.attention_utils as attn_utils
import ai_edge_torch.generative.layers.model_config as cfg
import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
import torch
from torch import nn

@@ -32,13 +28,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
ff_up_proj="model.layers.{}.mlp.up_proj",
ff_down_proj="model.layers.{}.mlp.down_proj",
ff_gate_proj="model.layers.{}.mlp.gate_proj",
-
-attn_key_proj="model.layers.{}.self_attn.k_proj",
-attn_value_proj="model.layers.{}.self_attn.v_proj",
+attn_fused_qkv_proj="model.layers.{}.self_attn.qkv_proj",
attn_output_proj="model.layers.{}.self_attn.o_proj",
pre_attn_norm="model.layers.{}.input_layernorm",
post_attn_norm="model.layers.{}.post_attention_layernorm",
-embedding="
+embedding="embedder",
final_norm="model.norm",
lm_head=None,
)
@@ -192,30 +186,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
loader.load(model, strict=False)
model.eval()
return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-"""Instantiates and runs a Gemma 2B model."""
-
-current_dir = pathlib.Path(__file__).parent.resolve()
-gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
-
-kv_cache_max_len = 1024
-model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-tokens[0, :4] = idx
-input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-kv = kv_utils.KVCache.from_model_config(model.config)
-output = model.forward(tokens, input_pos, kv)
-print("comparing with goldens..")
-assert torch.allclose(
-gemma_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-02
-)
-
-
-if __name__ == "__main__":
-input_checkpoint_path = os.path.join(
-pathlib.Path.home(), "Downloads/llm_data/gemma-2b"
-)
-define_and_run_2b(input_checkpoint_path)
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -267,29 +267,3 @@ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
loader.load(model, strict=False)
model.eval()
return model
-
-
-def define_and_run_2b(checkpoint_path: str) -> None:
-"""Instantiates and runs a Gemma2 2B model."""
-
-current_dir = pathlib.Path(__file__).parent.resolve()
-gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
-print("Running GEMMA 2")
-kv_cache_max_len = 1024
-model = build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
-toks = torch.from_numpy(
-np.array([2, 651, 9456, 576, 573, 3520, 3858, 603, 235248])
-)
-tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-tokens[0, :9] = toks
-input_pos = torch.arange(0, kv_cache_max_len, dtype=torch.int)
-kv = kv_utils.KVCache.from_model_config(model.config)
-out = model.forward(tokens, input_pos, kv)
-out_final = out["logits"][0, 8, :]
-assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
-
-
-if __name__ == "__main__":
-torch.set_printoptions(sci_mode=True)
-path = os.path.join(pathlib.Path.home(), "Downloads/llm_data/gemma2-2b")
-define_and_run_2b(path)
ai_edge_torch/generative/examples/gemma/verify_gemma1.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma1 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma1
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+"prompts",
+"What is the meaning of life?",
+"The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+"max_new_tokens",
+30,
+"The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+checkpoint = kagglehub.model_download("google/gemma/pyTorch/2b-it")
+
+verifier.log_msg("Building the reauthored model from", checkpoint)
+reauthored_model = gemma1.build_2b_model(checkpoint)
+
+verify_util.verify_reauthored_gemma_model(
+checkpoint=checkpoint,
+variant="2b",
+reauthored_model=reauthored_model,
+weight_filename="gemma-2b-it.ckpt",
+generate_prompts=_PROMPTS.value,
+forward_input_ids=[[1, 2, 3, 4]],
+max_new_tokens=_MAX_NEW_TOKENS.value,
+)
+
+
+if __name__ == "__main__":
+app.run(main)
ai_edge_torch/generative/examples/gemma/verify_gemma2.py
ADDED
@@ -0,0 +1,55 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Verifies the reauthored Gemma2 model."""
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.gemma import gemma2
+from ai_edge_torch.generative.examples.gemma import verify_util
+from ai_edge_torch.generative.utilities import verifier
+import kagglehub
+
+_PROMPTS = flags.DEFINE_multi_string(
+"prompts",
+"What is the meaning of life?",
+"The input prompts to generate answers.",
+)
+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+"max_new_tokens",
+30,
+"The maximum size of the generated tokens.",
+)
+
+
+def main(_):
+checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
+
+verifier.log_msg("Building the reauthored model from", checkpoint)
+reauthored_model = gemma2.build_2b_model(checkpoint)
+
+verify_util.verify_reauthored_gemma_model(
+checkpoint=checkpoint,
+variant="2b-v2",
+reauthored_model=reauthored_model,
+generate_prompts=_PROMPTS.value,
+forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
+max_new_tokens=_MAX_NEW_TOKENS.value,
+atol=1e-04,
+)
+
+
+if __name__ == "__main__":
+app.run(main)
ai_edge_torch/generative/examples/gemma/verify_util.py
ADDED
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility functions to verify the reauthored Gemma model."""
+
+import dataclasses
+import os
+from typing import List, Tuple
+
+import ai_edge_torch.generative.layers.attention_utils as attn_utils
+from ai_edge_torch.generative.utilities import verifier
+from gemma import config as gemma_config
+from gemma import model as gemma_model
+import torch
+
+
+@dataclasses.dataclass
+class _Output:
+logits: torch.Tensor
+
+
+class GemmaWrapper(verifier.ModelWrapper):
+"""Gemma model wrapper for verification.
+
+Verifier calls model.forward() with maxium sequence length (1024) expecting
+the output has 'logits' field while Gemma gets the input tokens with the
+actual length and returns logits in a tuple.
+
+Verifier runs tokenizer before model.generate() while Gemma runs the tokenizer
+inside model.generate().
+"""
+
+def __init__(self, model: torch.nn.Module, max_new_tokens: int):
+super().__init__(model)
+self.max_new_tokens = max_new_tokens
+
+def _get_actual_input_len(self, tokens: torch.Tensor) -> int:
+for i in range(tokens.shape[1]):
+if tokens[0, i] == 0:
+return i
+return tokens.shape[1]
+
+def _get_kv_caches(
+self, max_seq_len: int
+) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+config = self.model.config
+cache_size = (1, max_seq_len, config.num_key_value_heads, config.head_dim)
+cache = torch.zeros(cache_size)
+return [
+(cache.clone(), cache.clone()) for _ in range(config.num_hidden_layers)
+]
+
+def forward(self, tokens: torch.Tensor) -> _Output:
+"""Forwards the model after reducing input tokens to the actual length."""
+actual_input_len = self._get_actual_input_len(tokens)
+input_pos = torch.arange(0, actual_input_len, dtype=torch.long)
+mask_cache = attn_utils.build_causal_mask_cache(tokens.shape[1])
+_, logits = self.model.forward(
+input_token_ids=tokens[0, :actual_input_len].unsqueeze(0),
+input_positions=input_pos,
+kv_write_indices=None,
+kv_caches=self._get_kv_caches(tokens.shape[1]),
+mask=mask_cache.index_select(2, input_pos),
+output_positions=input_pos,
+temperatures=None,
+top_ps=torch.tensor([1.0], dtype=torch.float),
+top_ks=torch.tensor([1], dtype=torch.long),
+)
+return _Output(logits.float())
+
+def generate(self, tokens: torch.Tensor) -> torch.Tensor:
+"""Generates the response after decoding the tokens into a string."""
+prompts = self.model.tokenizer.decode(tokens[0].tolist())
+response = self.model.generate(
+prompts, device="cpu", output_len=self.max_new_tokens, top_k=1
+)
+return torch.tensor([self.model.tokenizer.encode(prompts + response)])
+
+
+class TokenizerWrapper(torch.nn.Module):
+"""Tokenizer wrapper for verification.
+
+Verifier expects the tokenizer to handle tokens in torch.Tensor while Gemma
+tokenizer expects tokens in a list.
+"""
+
+def __init__(self, tokenizer: torch.nn.Module):
+super().__init__()
+self.tokenizer = tokenizer
+
+def encode(self, text: str, **_) -> torch.Tensor:
+"""Adds one more dimension to the output of the tokenizer."""
+return torch.tensor([self.tokenizer.encode(text)])
+
+def decode(self, tokens: torch.Tensor) -> str:
+"""Decodes the token sequence after converting to a list."""
+return self.tokenizer.decode(tokens.tolist())
+
+
+def verify_reauthored_gemma_model(
+checkpoint: str,
+variant: str,
+reauthored_model: torch.nn.Module,
+generate_prompts: List[str],
+forward_input_ids: List[List[int]],
+weight_filename: str = "model.ckpt",
+tokenizer_filename: str = "tokenizer.model",
+max_new_tokens: int = 20,
+rtol: float = 1e-05,
+atol: float = 1e-05,
+):
+"""Verifies the reauthored Gemma model against the original model."""
+config = gemma_config.get_model_config(variant)
+config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
+# Use float32 to be compatible with the reauthored model.
+config.dtype = torch.float32
+
+verifier.log_msg("Loading the original model from", checkpoint)
+original_model = gemma_model.GemmaForCausalLM(config).eval()
+original_model.load_weights(os.path.join(checkpoint, weight_filename))
+
+verifier.verify_reauthored_model(
+original_model=GemmaWrapper(original_model, max_new_tokens),
+reauthored_model=reauthored_model,
+tokenizer=TokenizerWrapper(original_model.tokenizer),
+generate_prompts=generate_prompts,
+forward_input_ids=forward_input_ids,
+rtol=rtol,
+atol=atol,
+)
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
-'/tmp/
+'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
-
+1024,
'The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
-
+1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
pytorch_model = openelm.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
+quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+output_filename = f'openelm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
-tflite_path=_TFLITE_PATH.value,
+tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
)
ai_edge_torch/generative/examples/openelm/verify.py
CHANGED
@@ -33,8 +33,10 @@ _PROMPTS = flags.DEFINE_multi_string(
def main(_):
checkpoint = "apple/OpenELM-3B"
verifier.log_msg("Loading the original model from", checkpoint)
-
-
+wrapper_model = verifier.ModelWrapper(
+model=transformers.AutoModelForCausalLM.from_pretrained(
+checkpoint, trust_remote_code=True
+),
)

# Locate the cached dir.
@@ -50,10 +52,10 @@ def main(_):
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

verifier.verify_reauthored_model(
-original_model=
+original_model=wrapper_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
-
+generate_prompts=_PROMPTS.value,
)

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
-'/tmp/
+'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
-
+1024,
'The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
-
+1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
pytorch_model = phi2.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
+quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+output_filename = f'phi2_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
-tflite_path=_TFLITE_PATH.value,
+tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
)
ai_edge_torch/generative/examples/phi/verify.py
CHANGED
@@ -24,15 +24,25 @@ import transformers

_PROMPTS = flags.DEFINE_multi_string(
"prompts",
-"
+"Instruct: Write an email about the weather Output:",
"The input prompts to generate answers.",
)

+_MAX_NEW_TOKENS = flags.DEFINE_integer(
+"max_new_tokens",
+30,
+"The maximum size of the generated tokens.",
+)

def main(_):
checkpoint = kagglehub.model_download("Microsoft/phi/transformers/2")
verifier.log_msg("Loading the original model from", checkpoint)
-
+generation_config = transformers.GenerationConfig.from_pretrained(checkpoint)
+generation_config.max_new_tokens = _MAX_NEW_TOKENS.value
+wrapper_model = verifier.ModelWrapper(
+model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+hf_generation_config=generation_config,
+)

verifier.log_msg("Building the reauthored model from", checkpoint)
reauthored_model = phi2.build_model(checkpoint)
@@ -41,10 +51,10 @@ def main(_):
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

verifier.verify_reauthored_model(
-original_model=
+original_model=wrapper_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
-
+generate_prompts=_PROMPTS.value,
atol=1e-03,
)

ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
-'/tmp/
+'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
-
+1024,
'The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
-
+1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
pytorch_model = smollm.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
+quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+output_filename = f'smollm_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
-tflite_path=_TFLITE_PATH.value,
+tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
)
ai_edge_torch/generative/examples/smollm/verify.py
CHANGED
@@ -33,8 +33,9 @@ _PROMPTS = flags.DEFINE_multi_string(
def main(_):
checkpoint = "HuggingFaceTB/SmolLM-135M"
verifier.log_msg("Loading the original model from", checkpoint)
-
-
+wrapper_model = verifier.ModelWrapper(
+model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
+)
# Locate the cached dir.
cached_config_file = transformers.utils.cached_file(
checkpoint, transformers.utils.CONFIG_NAME
@@ -47,10 +48,10 @@ def main(_):
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

verifier.verify_reauthored_model(
-original_model=
+original_model=wrapper_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
-
+generate_prompts=_PROMPTS.value,
atol=1e-04,
)

ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
CHANGED
@@ -30,17 +30,17 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
)
_TFLITE_PATH = flags.DEFINE_string(
'tflite_path',
-'/tmp/
+'/tmp/',
'The tflite file path to export.',
)
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
'prefill_seq_len',
-
+1024,
'The maximum size of prefill input tensor.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
'kv_cache_max_len',
-
+1280,
'The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
@@ -54,9 +54,11 @@ def main(_):
pytorch_model = tiny_llama.build_model(
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
)
+quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+output_filename = f'tinyllama_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
converter.convert_to_tflite(
pytorch_model,
-tflite_path=_TFLITE_PATH.value,
+tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
prefill_seq_len=_PREFILL_SEQ_LEN.value,
quantize=_QUANTIZE.value,
)
ai_edge_torch/generative/examples/tiny_llama/verify.py
CHANGED
@@ -33,10 +33,11 @@ _PROMPTS = flags.DEFINE_multi_string(
def main(_):
checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
verifier.log_msg("Loading the original model from", checkpoint)
-
-
+wrapper_model = verifier.ModelWrapper(
+model=transformers.AutoModelForCausalLM.from_pretrained(
+checkpoint, trust_remote_code=True
+),
)
-
# Locate the cached dir.
cached_config_file = transformers.utils.cached_file(
checkpoint, transformers.utils.CONFIG_NAME
@@ -49,10 +50,10 @@ def main(_):
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

verifier.verify_reauthored_model(
-original_model=
+original_model=wrapper_model,
reauthored_model=reauthored_model,
tokenizer=tokenizer,
-
+generate_prompts=_PROMPTS.value,
atol=1e-04,
)

ai_edge_torch/generative/quantize/example.py
CHANGED
@@ -14,7 +14,7 @@
# ==============================================================================

import ai_edge_torch
-from ai_edge_torch.generative.examples.gemma import
+from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.quantize import quant_recipes
import numpy as np
import torch
@@ -22,8 +22,8 @@ import torch

def main():
# Build a PyTorch model as usual
-config =
-model =
+config = gemma1.get_fake_model_config()
+model = gemma1.Gemma(config)
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
tokens[0, :4] = idx
ai_edge_torch/generative/test/test_model_conversion_large.py
CHANGED
@@ -17,7 +17,7 @@

import ai_edge_torch
from ai_edge_torch import config as ai_edge_config
-from ai_edge_torch.generative.examples.gemma import
+from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.examples.openelm import openelm
from ai_edge_torch.generative.examples.phi import phi2
@@ -28,7 +28,7 @@ import numpy as np
import torch

from absl.testing import absltest as googletest
-from
+from ai_edge_litert import interpreter
 

class TestModelConversion(googletest.TestCase):
@@ -82,9 +82,9 @@ class TestModelConversion(googletest.TestCase):
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
)
-def
-config =
-pytorch_model =
+def test_gemma1(self):
+config = gemma1.get_fake_model_config()
+pytorch_model = gemma1.Gemma(config).eval()
self._test_model(
config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
)
ai_edge_torch/generative/utilities/verifier.py
CHANGED
@@ -16,17 +16,65 @@
"""Common utility functions to verify the reauthored models."""

import datetime
-from typing import List
+from typing import List, Optional, Union

from ai_edge_torch.generative.layers import kv_cache as kv_utils
-import numpy as np
import torch
+import transformers


def log_msg(*args):
print("[%s]" % datetime.datetime.now(), *args)


+class ModelWrapper(torch.nn.Module):
+"""A wrapper for the model to be verified, this could be a HuggingFace model
+
+or a regular PyTorch model.
+"""
+
+def __init__(
+self,
+model: torch.nn.Module,
+model_format: str = "huggingface",
+hf_generation_config: Optional[transformers.GenerationConfig] = None,
+):
+"""Initializes the wrapper.
+
+Args:
+model (torch.nn.Module): The original model. This could be a model built
+from HuggingFace transformers, or a regular PyTorch model.
+model_format (str): The format of the model. It should be either
+"huggingface" or "pytorch".
+hf_generation_config (transformers.GenerationConfig): The HuggingFace
+generation config. This config will only be used if the underlying model
+is built from HuggingFace transformers.
+"""
+super().__init__()
+self.model = model
+self.model_format = model_format
+self.hf_generation_config = hf_generation_config
+
+def generate(
+self, inputs: torch.Tensor
+) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
+if self.model_format == "huggingface":
+return self.model.generate(
+inputs=inputs, generation_config=self.hf_generation_config
+)
+else:
+raise NotImplementedError(
+"generate() is not implemented for model format: %s"
+% self.model_format
+)
+
+def forward(
+self,
+inputs: torch.Tensor,
+):
+return self.model.forward(inputs)
+
+
def forward(
model: torch.nn.Module,
tokens: torch.Tensor,
@@ -75,9 +123,9 @@ def generate(


def verify_with_input_ids(
-original_model:
+original_model: ModelWrapper,
reauthored_model: torch.nn.Module,
-input_ids:
+input_ids: List[int],
kv_cache_max_len: int = 1024,
rtol: float = 1e-05,
atol: float = 1e-05,
@@ -87,10 +135,10 @@ def verify_with_input_ids(
It compares only one outputs from the original and the reauthored model.

Args:
-original_model (
+original_model (ModelWrapper): The original model.
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
Generative API.
-input_ids (
+input_ids (List[int]): The input token IDs to forward with.
kv_cache_max_len (int): The maximum sequence length of the KV cache.
rtol (float): The relative tolerance for the comparison.
atol (float): The absolute tolerance for the comparison.
@@ -99,18 +147,17 @@ def verify_with_input_ids(
True if the model reauthored generates the same output of the original.
"""
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
-
-tokens[0, :input_ids_len] = input_ids
+tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()

log_msg("Forwarding the original model...")
outputs_original = original_model.forward(tokens)
-logits_original = outputs_original.logits[0,
+logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
log_msg("logits_original: ", logits_original)

log_msg("Forwarding the reauthored model...")
kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
-logits_reauthored = outputs_reauthored[0][0,
+logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
log_msg("logits_reauthored:", logits_reauthored)

return torch.allclose(
@@ -119,7 +166,7 @@ def verify_with_input_ids(


def verify_model_with_prompts(
-original_model:
+original_model: ModelWrapper,
reauthored_model: torch.nn.Module,
tokenizer: torch.nn.Module,
prompts: str,
@@ -130,7 +177,7 @@ def verify_model_with_prompts(
original and the reauthored model.

Args:
-original_model (
+original_model (ModelWrapper): The original model.
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
Generative API.
tokenizer (torch.nn.Module): The tokenizer.
@@ -156,10 +203,11 @@ def verify_model_with_prompts(


def verify_reauthored_model(
-original_model:
+original_model: ModelWrapper,
reauthored_model: torch.nn.Module,
tokenizer: torch.nn.Module,
-
+generate_prompts: List[str],
+forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
rtol: float = 1e-05,
atol: float = 1e-05,
):
@@ -174,26 +222,29 @@ def verify_reauthored_model(
It prints out "PASS" or "FAILED" to the console.

Args:
-original_model (
+original_model (ModelWrapper): The original model.
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
Generative API.
tokenizer (torch.nn.Module): The tokenizer.
-
+generate_prompts (List[str]): List of the input prompts to generate answers.
+forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
+forward with.
rtol (float): The relative tolerance for the comparison.
atol (float): The absolute tolerance for the comparison.
"""
-
-
-
-
-
-
-
+for input_ids in forward_input_ids:
+log_msg("Verifying the reauthored model with input IDs:", input_ids)
+if verify_with_input_ids(
+original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
+):
+log_msg("PASS")
+else:
+log_msg("FAILED")

-for
-log_msg("Verifying the reauthored model with prompts:",
+for prompts in generate_prompts:
+log_msg("Verifying the reauthored model with prompts:", prompts)
if verify_model_with_prompts(
-original_model, reauthored_model, tokenizer,
+original_model, reauthored_model, tokenizer, prompts
):
log_msg("PASS")
else:
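Taken together, the verifier changes above replace the bare torch.nn.Module argument with the new ModelWrapper and make the prompts and forward input IDs explicit parameters. A condensed usage sketch based on the SmolLM verify script in this diff (assumption: the real script first resolves the locally cached checkpoint directory before calling build_model, which is omitted here):

import transformers
from ai_edge_torch.generative.examples.smollm import smollm
from ai_edge_torch.generative.utilities import verifier

checkpoint = "HuggingFaceTB/SmolLM-135M"
# Wrap the original HuggingFace model so the verifier can call generate() and forward() on it.
wrapper_model = verifier.ModelWrapper(
    model=transformers.AutoModelForCausalLM.from_pretrained(checkpoint),
)
# Model rebuilt with the Generative API (see the full verify.py for checkpoint-cache handling).
reauthored_model = smollm.build_model(checkpoint)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)

verifier.verify_reauthored_model(
    original_model=wrapper_model,
    reauthored_model=reauthored_model,
    tokenizer=tokenizer,
    generate_prompts=["What is the meaning of life?"],
    atol=1e-04,
)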
ai_edge_torch/model.py
CHANGED
@@ -27,6 +27,8 @@ from typing import Callable
import numpy.typing as npt
import tensorflow as tf

+from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import
+
DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY


@@ -65,7 +67,7 @@ class TfLiteModel(Model):
tflite_model: A TFlite serialized object.
"""
self._tflite_model = tflite_model
-self._interpreter_builder = lambda:
+self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
model_content=self._tflite_model,
experimental_default_delegate_latest_features=True,
)
@@ -75,12 +77,13 @@ class TfLiteModel(Model):
return self._tflite_model

def set_interpreter_builder(
-self, builder: Callable[[],
+self, builder: Callable[[], tfl_interpreter.Interpreter]
) -> None:
"""Sets a custom interpreter builder.

Args:
-builder: A function that returns a `
+builder: A function that returns a `tfl_interpreter.Interpreter` or its
+subclass.
"""
self._interpreter_builder = builder

@@ -166,7 +169,7 @@ class TfLiteModel(Model):

# Check if this is indeed a tflite model:
try:
-interpreter =
+interpreter = tfl_interpreter.Interpreter(model_content=model_content)
interpreter.get_signature_list()
except:
return None
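model.py now builds its default interpreter from ai_edge_litert rather than TensorFlow's bundled TFLite interpreter, while the builder hook stays available for callers who need their own. A minimal sketch that reproduces the new default builder (the toy model and sample inputs are placeholders, and reading _tflite_model mirrors what test_convert.py does above):

import torch
import ai_edge_torch
from ai_edge_litert import interpreter as tfl_interpreter

# Placeholder model and sample inputs purely for illustration.
model = torch.nn.Linear(4, 2).eval()
sample_args = (torch.randn(1, 4),)

edge_model = ai_edge_torch.convert(model, sample_args)

# Rebuild the interpreter the same way the new default builder does.
edge_model.set_interpreter_builder(
    lambda: tfl_interpreter.Interpreter(
        model_content=edge_model._tflite_model,
        experimental_default_delegate_latest_features=True,
    )
)
print(edge_model(*sample_args))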
ai_edge_torch/version.py
CHANGED
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20240921
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
Home-page: https://github.com/google-ai-edge/ai-edge-torch
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240919.dist-info → ai_edge_torch_nightly-0.3.0.dev20240921.dist-info}/RECORD
CHANGED
@@ -2,8 +2,8 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
-ai_edge_torch/model.py,sha256=
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
+ai_edge_torch/version.py,sha256=t9zajdsiowClI2fG0RkKVonPF-SUx9UBuUDOEZFU9y4,706
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -25,7 +25,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/test/test_convert.py,sha256=
+ai_edge_torch/_convert/test/test_convert.py,sha256=40QRxQFNeSRr4dLXJkzG-wKUlvJtsfv62cdvRrmBv5w,15097
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -39,22 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
+ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=KsM6VlzluTqbodG24IFr3biPxBrLay0z0gmnG0bcU2U,9277
+ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=B14IR4mIw6qBVUbiIRdfdUzHMCIJCJ0RFPsYOxA46qc,1776
+ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax1m4WTAnC3gMxBpcvGA7-xTO1Iuw,1802
+ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
-ai_edge_torch/generative/examples/openelm/verify.py,sha256=
+ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
-ai_edge_torch/generative/examples/phi/verify.py,sha256=
+ai_edge_torch/generative/examples/phi/verify.py,sha256=QPYX6weEZGMEXt_Vb2hNARPAECQBKzx-KCivd4dzOrw,2145
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
-ai_edge_torch/generative/examples/smollm/verify.py,sha256=
+ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
@@ -78,16 +81,16 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
-ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LUChL5tA7FHL_DlTg5QKvGInmH9AwVVw9a-omcndiz8,2095
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
-ai_edge_torch/generative/layers/feed_forward.py,sha256=
+ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
@@ -98,7 +101,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/example.py,sha256=
+ai_edge_torch/generative/quantize/example.py,sha256=tlACaRsz6lqOxakzpXVFJZYfFKOiFqetcYVJqWVRdPE,1542
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -107,8 +110,8 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mAK8Pm4mgGyilDSBtFazCRDetoqYKKB0sGC83MPKE0M,4494
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -116,7 +119,7 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=vU9KgmFS7I9jNS_3H2SWROx-rbNqtMKgQC2MRhdqQ4g,8803
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -163,8 +166,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/METADATA,sha256=SWy7BhOQDe0_SBF17deNndzt1bEYy7iXUxy0KznIPYM,1859
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240921.dist-info/RECORD,,
File without changes
File without changes