ai-edge-torch-nightly 0.3.0.dev20240914__py3-none-any.whl → 0.3.0.dev20240918__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -0
- ai_edge_torch/generative/layers/unet/model_config.py +2 -0
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +6 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/RECORD +20 -19
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/top_level.txt +0 -0
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/gemma2_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_gemma2_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts a Gemma2 2B model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = gemma2.build_2b_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_gemma2_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.gemma import gemma
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/gemma_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_gemma_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts a Gemma 2B model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = gemma.build_2b_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_gemma_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/openelm_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_openelm_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts OpenELM model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = openelm.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_openelm_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.phi import phi2
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/phi2_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_phi2_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts a Phi-2 model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = phi2.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_phi2_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/smollm_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_smollm_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts SmolLM model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = smollm.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_smollm_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
|
|
336
336
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
337
337
|
query_dim=output_channel,
|
338
338
|
cross_dim=config.transformer_cross_attention_dim,
|
339
|
+
hidden_dim=output_channel,
|
340
|
+
output_dim=output_channel,
|
339
341
|
attention_batch_size=config.transformer_batch_size,
|
340
342
|
normalization_config=config.transformer_norm_config,
|
341
343
|
attention_config=build_attention_config(
|
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
|
|
406
408
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
407
409
|
query_dim=mid_block_channels,
|
408
410
|
cross_dim=config.transformer_cross_attention_dim,
|
411
|
+
hidden_dim=mid_block_channels,
|
412
|
+
output_dim=mid_block_channels,
|
409
413
|
attention_batch_size=config.transformer_batch_size,
|
410
414
|
normalization_config=config.transformer_norm_config,
|
411
415
|
attention_config=build_attention_config(
|
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
|
|
477
481
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
478
482
|
query_dim=output_channel,
|
479
483
|
cross_dim=config.transformer_cross_attention_dim,
|
484
|
+
hidden_dim=output_channel,
|
485
|
+
output_dim=output_channel,
|
480
486
|
attention_batch_size=config.transformer_batch_size,
|
481
487
|
normalization_config=config.transformer_norm_config,
|
482
488
|
attention_config=build_attention_config(
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_tiny_llama_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts TinyLlama model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = tiny_llama.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_tiny_llama_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
|
|
298
298
|
batch_size: int,
|
299
299
|
query_dim: int,
|
300
300
|
cross_dim: int,
|
301
|
+
hidden_dim: int,
|
302
|
+
output_dim: int,
|
301
303
|
config: cfg.AttentionConfig,
|
302
304
|
enable_hlfb: bool,
|
303
305
|
):
|
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
|
|
307
309
|
batch_size (int): batch size of the input tensor.
|
308
310
|
query_dim (int): query tensor's dimension.
|
309
311
|
cross_dim (int): cross attention's dimensions, for key and value tensors.
|
312
|
+
hidden_dim (int): hidden dimension that q, k, v tensors project to.
|
313
|
+
output_dim (int): output tensor's dimension.
|
310
314
|
config (cfg.AttentionConfig): attention specific configurations.
|
311
315
|
enable_hlfb (bool): whether hlfb is enabled or not.
|
312
316
|
"""
|
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
|
|
314
318
|
self.config = config
|
315
319
|
self.n_heads = config.num_heads
|
316
320
|
self.q_projection = nn.Linear(
|
317
|
-
query_dim,
|
321
|
+
query_dim, hidden_dim, bias=config.qkv_use_bias
|
318
322
|
)
|
319
323
|
self.k_projection = nn.Linear(
|
320
|
-
cross_dim,
|
324
|
+
cross_dim, hidden_dim, bias=config.qkv_use_bias
|
321
325
|
)
|
322
326
|
self.v_projection = nn.Linear(
|
323
|
-
cross_dim,
|
327
|
+
cross_dim, hidden_dim, bias=config.qkv_use_bias
|
324
328
|
)
|
325
329
|
self.output_projection = nn.Linear(
|
326
|
-
|
330
|
+
hidden_dim, output_dim, bias=config.output_proj_use_bias
|
327
331
|
)
|
328
332
|
|
329
333
|
self.sdpa_func = (
|
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
|
|
68
68
|
class CrossAttentionBlock2DConfig:
|
69
69
|
query_dim: int
|
70
70
|
cross_dim: int
|
71
|
+
hidden_dim: int
|
72
|
+
output_dim: int
|
71
73
|
normalization_config: layers_cfg.NormalizationConfig
|
72
74
|
attention_config: layers_cfg.AttentionConfig
|
73
75
|
enable_hlfb: bool = True
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utility functions for model conversion."""
|
17
|
+
|
18
|
+
import ai_edge_torch
|
19
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
20
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
21
|
+
import torch
|
22
|
+
|
23
|
+
|
24
|
+
def convert_to_tflite(
|
25
|
+
pytorch_model: torch.nn.Module,
|
26
|
+
tflite_path: str,
|
27
|
+
prefill_seq_len: int = 512,
|
28
|
+
quantize: bool = True,
|
29
|
+
):
|
30
|
+
"""Converts a nn.Module model to multi-signature tflite model.
|
31
|
+
|
32
|
+
A PyTorch model will be converted to a tflite model with two signatures:
|
33
|
+
"prefill" and "decode".
|
34
|
+
|
35
|
+
"prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
|
36
|
+
sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
|
37
|
+
external KV cache as a sample input.
|
38
|
+
|
39
|
+
"decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
|
40
|
+
of shape [1, 1] of the token position, and an external KV cache as a sample
|
41
|
+
input.
|
42
|
+
|
43
|
+
The final tflite model will be exported to tflite_path.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
|
47
|
+
tflite_path (str): The tflite file path to export.
|
48
|
+
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
49
|
+
Defaults to 512.
|
50
|
+
quantize (bool, optional): Whether the model should be quanized. Defaults
|
51
|
+
to True.
|
52
|
+
"""
|
53
|
+
# Tensors used to trace the model graph during conversion.
|
54
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
55
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
56
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
57
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
58
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
59
|
+
|
60
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
61
|
+
edge_model = (
|
62
|
+
ai_edge_torch.signature(
|
63
|
+
'prefill',
|
64
|
+
pytorch_model,
|
65
|
+
sample_kwargs={
|
66
|
+
'tokens': prefill_tokens,
|
67
|
+
'input_pos': prefill_input_pos,
|
68
|
+
'kv_cache': kv,
|
69
|
+
},
|
70
|
+
)
|
71
|
+
.signature(
|
72
|
+
'decode',
|
73
|
+
pytorch_model,
|
74
|
+
sample_kwargs={
|
75
|
+
'tokens': decode_token,
|
76
|
+
'input_pos': decode_input_pos,
|
77
|
+
'kv_cache': kv,
|
78
|
+
},
|
79
|
+
)
|
80
|
+
.convert(quant_config=quant_config)
|
81
|
+
)
|
82
|
+
edge_model.export(tflite_path)
|
@@ -811,6 +811,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
811
811
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
812
812
|
query_dim=output_channel,
|
813
813
|
cross_dim=config.transformer_cross_attention_dim,
|
814
|
+
hidden_dim=output_channel,
|
815
|
+
output_dim=output_channel,
|
814
816
|
normalization_config=config.transformer_norm_config,
|
815
817
|
attention_config=build_attention_config(
|
816
818
|
num_heads=config.transformer_num_attention_heads,
|
@@ -877,6 +879,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
877
879
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
878
880
|
query_dim=mid_block_channels,
|
879
881
|
cross_dim=config.transformer_cross_attention_dim,
|
882
|
+
hidden_dim=mid_block_channels,
|
883
|
+
output_dim=mid_block_channels,
|
880
884
|
normalization_config=config.transformer_norm_config,
|
881
885
|
attention_config=build_attention_config(
|
882
886
|
num_heads=config.transformer_num_attention_heads,
|
@@ -950,6 +954,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
950
954
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
951
955
|
query_dim=output_channel,
|
952
956
|
cross_dim=config.transformer_cross_attention_dim,
|
957
|
+
hidden_dim=output_channel,
|
958
|
+
output_dim=output_channel,
|
953
959
|
normalization_config=config.transformer_norm_config,
|
954
960
|
attention_config=build_attention_config(
|
955
961
|
num_heads=config.transformer_num_attention_heads,
|
@@ -212,17 +212,25 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
|
|
212
212
|
# - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
|
213
213
|
@lower(torch.ops.aten.slice_scatter)
|
214
214
|
def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
215
|
-
start = start
|
216
|
-
end = end
|
215
|
+
start = start if start is not None else 0
|
216
|
+
end = end if end is not None else self.type.shape[dim]
|
217
|
+
|
218
|
+
start, end = np.clip(
|
219
|
+
[start, end], -self.type.shape[dim], self.type.shape[dim]
|
220
|
+
)
|
221
|
+
|
217
222
|
if start < 0:
|
218
223
|
start = self.type.shape[dim] + start
|
219
224
|
if end < 0:
|
220
225
|
end = self.type.shape[dim] + end
|
221
226
|
|
222
|
-
end
|
227
|
+
if end <= start or np.prod(src.type.shape) == 0:
|
228
|
+
return self
|
223
229
|
|
230
|
+
end = start + step * math.ceil((end - start) / step) - (step - 1)
|
224
231
|
padding_low = start
|
225
232
|
padding_high = self.type.shape[dim] - end
|
233
|
+
interior_padding = step - 1
|
226
234
|
|
227
235
|
rank = len(self.type.shape)
|
228
236
|
src = stablehlo.pad(
|
@@ -230,7 +238,9 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
230
238
|
utils.splat(0, src.type.element_type, []),
|
231
239
|
edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
|
232
240
|
edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
|
233
|
-
interior_padding=[
|
241
|
+
interior_padding=[
|
242
|
+
interior_padding if i == dim else 0 for i in range(rank)
|
243
|
+
],
|
234
244
|
)
|
235
245
|
pred = np.ones(self.type.shape, dtype=np.bool_)
|
236
246
|
pred[*[
|
@@ -57,6 +57,7 @@ global_registry.decompositions.update(
|
|
57
57
|
torch._decomp.get_decompositions([
|
58
58
|
torch.ops.aten.upsample_nearest2d,
|
59
59
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
60
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
60
61
|
torch.ops.aten._adaptive_avg_pool2d,
|
61
62
|
torch.ops.aten._adaptive_avg_pool3d,
|
62
63
|
torch.ops.aten.grid_sampler_2d,
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240918
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=jWg5qA8V0XqgFoqjk0SCsNWPRBeTmfrir9u0bucHYOU,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -39,25 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
|
|
39
39
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
40
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
41
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
43
|
-
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
|
42
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
|
43
|
+
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
|
44
44
|
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
|
45
45
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
|
46
46
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
|
-
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
|
47
|
+
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
|
48
48
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
|
49
49
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
|
-
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
|
50
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
|
51
51
|
ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
|
52
52
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
53
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
53
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
|
54
54
|
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
|
55
55
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
56
56
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
57
57
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
|
58
58
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
59
59
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
|
60
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
|
61
61
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
62
62
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
63
63
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -75,12 +75,12 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
|
|
75
75
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
|
76
76
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
|
77
77
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
78
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
78
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
|
79
79
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
|
80
80
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
81
81
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
82
82
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
83
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
83
|
+
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
84
84
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
85
85
|
ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
|
86
86
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
|
@@ -90,9 +90,9 @@ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvC
|
|
90
90
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
91
91
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
92
92
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
93
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
|
94
94
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
95
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
|
96
96
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
97
|
ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
|
98
98
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
@@ -108,8 +108,9 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5
|
|
108
108
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
109
109
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
110
110
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
111
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
|
111
112
|
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
|
112
|
-
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=
|
113
|
+
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
|
113
114
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
114
115
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
115
116
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
@@ -141,13 +142,13 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
141
142
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
142
143
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
143
144
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
144
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
145
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VvB050UCjB17h6-UNtsaqzVF13MGI01fPFkdmmghTj4,8790
|
145
146
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
146
147
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
147
148
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
|
148
149
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
149
150
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
150
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
151
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
|
151
152
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
152
153
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
153
154
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -157,8 +158,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
157
158
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
158
159
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
159
160
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
160
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
-
ai_edge_torch_nightly-0.3.0.
|
162
|
-
ai_edge_torch_nightly-0.3.0.
|
163
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
162
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/METADATA,sha256=dMaIr8Iny84IfNGQGSrtlTGkYlH_mAMmgvGWm5-pkxM,1859
|
163
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/RECORD,,
|
File without changes
|
File without changes
|