ai-edge-torch-nightly 0.3.0.dev20240914__py3-none-any.whl → 0.3.0.dev20240918__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -0
- ai_edge_torch/generative/layers/unet/model_config.py +2 -0
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +6 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/RECORD +20 -19
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240914.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/top_level.txt +0 -0
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/gemma2_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_gemma2_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts a Gemma2 2B model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = gemma2.build_2b_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_gemma2_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.gemma import gemma
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/gemma_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_gemma_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts a Gemma 2B model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = gemma.build_2b_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_gemma_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.openelm import openelm
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/openelm_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_openelm_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts OpenELM model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = openelm.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_openelm_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.phi import phi2
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/phi2_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_phi2_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts a Phi-2 model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = phi2.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_phi2_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.smollm import smollm
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/smollm_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_smollm_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts SmolLM model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = smollm.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_smollm_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
|
|
336
336
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
337
337
|
query_dim=output_channel,
|
338
338
|
cross_dim=config.transformer_cross_attention_dim,
|
339
|
+
hidden_dim=output_channel,
|
340
|
+
output_dim=output_channel,
|
339
341
|
attention_batch_size=config.transformer_batch_size,
|
340
342
|
normalization_config=config.transformer_norm_config,
|
341
343
|
attention_config=build_attention_config(
|
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
|
|
406
408
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
407
409
|
query_dim=mid_block_channels,
|
408
410
|
cross_dim=config.transformer_cross_attention_dim,
|
411
|
+
hidden_dim=mid_block_channels,
|
412
|
+
output_dim=mid_block_channels,
|
409
413
|
attention_batch_size=config.transformer_batch_size,
|
410
414
|
normalization_config=config.transformer_norm_config,
|
411
415
|
attention_config=build_attention_config(
|
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
|
|
477
481
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
478
482
|
query_dim=output_channel,
|
479
483
|
cross_dim=config.transformer_cross_attention_dim,
|
484
|
+
hidden_dim=output_channel,
|
485
|
+
output_dim=output_channel,
|
480
486
|
attention_batch_size=config.transformer_batch_size,
|
481
487
|
normalization_config=config.transformer_norm_config,
|
482
488
|
attention_config=build_attention_config(
|
@@ -18,69 +18,49 @@
|
|
18
18
|
import os
|
19
19
|
import pathlib
|
20
20
|
|
21
|
-
import
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
22
23
|
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
|
23
|
-
from ai_edge_torch.generative.
|
24
|
-
from ai_edge_torch.generative.quantize import quant_recipes
|
25
|
-
import torch
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
26
25
|
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
|
+
'prefill_seq_len',
|
38
|
+
512,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1024,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
27
51
|
|
28
|
-
def convert_tiny_llama_to_tflite(
|
29
|
-
checkpoint_path: str,
|
30
|
-
prefill_seq_len: int = 512,
|
31
|
-
kv_cache_max_len: int = 1024,
|
32
|
-
quantize: bool = True,
|
33
|
-
):
|
34
|
-
"""Converts TinyLlama model to multi-signature tflite model.
|
35
52
|
|
36
|
-
|
37
|
-
checkpoint_path (str): The filepath to the model checkpoint, or directory
|
38
|
-
holding the checkpoint.
|
39
|
-
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
40
|
-
Defaults to 512.
|
41
|
-
kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
|
42
|
-
including both prefill and decode. Defaults to 1024.
|
43
|
-
quantize (bool, optional): Whether the model should be quanized. Defaults
|
44
|
-
to True.
|
45
|
-
"""
|
53
|
+
def main(_):
|
46
54
|
pytorch_model = tiny_llama.build_model(
|
47
|
-
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
48
56
|
)
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
55
|
-
|
56
|
-
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
57
|
-
edge_model = (
|
58
|
-
ai_edge_torch.signature(
|
59
|
-
'prefill',
|
60
|
-
pytorch_model,
|
61
|
-
sample_kwargs={
|
62
|
-
'tokens': prefill_tokens,
|
63
|
-
'input_pos': prefill_input_pos,
|
64
|
-
'kv_cache': kv,
|
65
|
-
},
|
66
|
-
)
|
67
|
-
.signature(
|
68
|
-
'decode',
|
69
|
-
pytorch_model,
|
70
|
-
sample_kwargs={
|
71
|
-
'tokens': decode_token,
|
72
|
-
'input_pos': decode_input_pos,
|
73
|
-
'kv_cache': kv,
|
74
|
-
},
|
75
|
-
)
|
76
|
-
.convert(quant_config=quant_config)
|
77
|
-
)
|
78
|
-
quant_suffix = 'q8' if quantize else 'f32'
|
79
|
-
edge_model.export(
|
80
|
-
f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
|
57
|
+
converter.convert_to_tflite(
|
58
|
+
pytorch_model,
|
59
|
+
tflite_path=_TFLITE_PATH.value,
|
60
|
+
prefill_seq_len=_PREFILL_SEQ_LEN.value,
|
61
|
+
quantize=_QUANTIZE.value,
|
81
62
|
)
|
82
63
|
|
83
64
|
|
84
65
|
if __name__ == '__main__':
|
85
|
-
|
86
|
-
convert_tiny_llama_to_tflite(path)
|
66
|
+
app.run(main)
|
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
|
|
298
298
|
batch_size: int,
|
299
299
|
query_dim: int,
|
300
300
|
cross_dim: int,
|
301
|
+
hidden_dim: int,
|
302
|
+
output_dim: int,
|
301
303
|
config: cfg.AttentionConfig,
|
302
304
|
enable_hlfb: bool,
|
303
305
|
):
|
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
|
|
307
309
|
batch_size (int): batch size of the input tensor.
|
308
310
|
query_dim (int): query tensor's dimension.
|
309
311
|
cross_dim (int): cross attention's dimensions, for key and value tensors.
|
312
|
+
hidden_dim (int): hidden dimension that q, k, v tensors project to.
|
313
|
+
output_dim (int): output tensor's dimension.
|
310
314
|
config (cfg.AttentionConfig): attention specific configurations.
|
311
315
|
enable_hlfb (bool): whether hlfb is enabled or not.
|
312
316
|
"""
|
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
|
|
314
318
|
self.config = config
|
315
319
|
self.n_heads = config.num_heads
|
316
320
|
self.q_projection = nn.Linear(
|
317
|
-
query_dim,
|
321
|
+
query_dim, hidden_dim, bias=config.qkv_use_bias
|
318
322
|
)
|
319
323
|
self.k_projection = nn.Linear(
|
320
|
-
cross_dim,
|
324
|
+
cross_dim, hidden_dim, bias=config.qkv_use_bias
|
321
325
|
)
|
322
326
|
self.v_projection = nn.Linear(
|
323
|
-
cross_dim,
|
327
|
+
cross_dim, hidden_dim, bias=config.qkv_use_bias
|
324
328
|
)
|
325
329
|
self.output_projection = nn.Linear(
|
326
|
-
|
330
|
+
hidden_dim, output_dim, bias=config.output_proj_use_bias
|
327
331
|
)
|
328
332
|
|
329
333
|
self.sdpa_func = (
|
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
|
|
68
68
|
class CrossAttentionBlock2DConfig:
|
69
69
|
query_dim: int
|
70
70
|
cross_dim: int
|
71
|
+
hidden_dim: int
|
72
|
+
output_dim: int
|
71
73
|
normalization_config: layers_cfg.NormalizationConfig
|
72
74
|
attention_config: layers_cfg.AttentionConfig
|
73
75
|
enable_hlfb: bool = True
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utility functions for model conversion."""
|
17
|
+
|
18
|
+
import ai_edge_torch
|
19
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
20
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
21
|
+
import torch
|
22
|
+
|
23
|
+
|
24
|
+
def convert_to_tflite(
|
25
|
+
pytorch_model: torch.nn.Module,
|
26
|
+
tflite_path: str,
|
27
|
+
prefill_seq_len: int = 512,
|
28
|
+
quantize: bool = True,
|
29
|
+
):
|
30
|
+
"""Converts a nn.Module model to multi-signature tflite model.
|
31
|
+
|
32
|
+
A PyTorch model will be converted to a tflite model with two signatures:
|
33
|
+
"prefill" and "decode".
|
34
|
+
|
35
|
+
"prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
|
36
|
+
sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
|
37
|
+
external KV cache as a sample input.
|
38
|
+
|
39
|
+
"decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
|
40
|
+
of shape [1, 1] of the token position, and an external KV cache as a sample
|
41
|
+
input.
|
42
|
+
|
43
|
+
The final tflite model will be exported to tflite_path.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
|
47
|
+
tflite_path (str): The tflite file path to export.
|
48
|
+
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
49
|
+
Defaults to 512.
|
50
|
+
quantize (bool, optional): Whether the model should be quanized. Defaults
|
51
|
+
to True.
|
52
|
+
"""
|
53
|
+
# Tensors used to trace the model graph during conversion.
|
54
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
55
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
56
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
57
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
58
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
59
|
+
|
60
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
61
|
+
edge_model = (
|
62
|
+
ai_edge_torch.signature(
|
63
|
+
'prefill',
|
64
|
+
pytorch_model,
|
65
|
+
sample_kwargs={
|
66
|
+
'tokens': prefill_tokens,
|
67
|
+
'input_pos': prefill_input_pos,
|
68
|
+
'kv_cache': kv,
|
69
|
+
},
|
70
|
+
)
|
71
|
+
.signature(
|
72
|
+
'decode',
|
73
|
+
pytorch_model,
|
74
|
+
sample_kwargs={
|
75
|
+
'tokens': decode_token,
|
76
|
+
'input_pos': decode_input_pos,
|
77
|
+
'kv_cache': kv,
|
78
|
+
},
|
79
|
+
)
|
80
|
+
.convert(quant_config=quant_config)
|
81
|
+
)
|
82
|
+
edge_model.export(tflite_path)
|
@@ -811,6 +811,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
811
811
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
812
812
|
query_dim=output_channel,
|
813
813
|
cross_dim=config.transformer_cross_attention_dim,
|
814
|
+
hidden_dim=output_channel,
|
815
|
+
output_dim=output_channel,
|
814
816
|
normalization_config=config.transformer_norm_config,
|
815
817
|
attention_config=build_attention_config(
|
816
818
|
num_heads=config.transformer_num_attention_heads,
|
@@ -877,6 +879,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
877
879
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
878
880
|
query_dim=mid_block_channels,
|
879
881
|
cross_dim=config.transformer_cross_attention_dim,
|
882
|
+
hidden_dim=mid_block_channels,
|
883
|
+
output_dim=mid_block_channels,
|
880
884
|
normalization_config=config.transformer_norm_config,
|
881
885
|
attention_config=build_attention_config(
|
882
886
|
num_heads=config.transformer_num_attention_heads,
|
@@ -950,6 +954,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
950
954
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
951
955
|
query_dim=output_channel,
|
952
956
|
cross_dim=config.transformer_cross_attention_dim,
|
957
|
+
hidden_dim=output_channel,
|
958
|
+
output_dim=output_channel,
|
953
959
|
normalization_config=config.transformer_norm_config,
|
954
960
|
attention_config=build_attention_config(
|
955
961
|
num_heads=config.transformer_num_attention_heads,
|
@@ -212,17 +212,25 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
|
|
212
212
|
# - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
|
213
213
|
@lower(torch.ops.aten.slice_scatter)
|
214
214
|
def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
215
|
-
start = start
|
216
|
-
end = end
|
215
|
+
start = start if start is not None else 0
|
216
|
+
end = end if end is not None else self.type.shape[dim]
|
217
|
+
|
218
|
+
start, end = np.clip(
|
219
|
+
[start, end], -self.type.shape[dim], self.type.shape[dim]
|
220
|
+
)
|
221
|
+
|
217
222
|
if start < 0:
|
218
223
|
start = self.type.shape[dim] + start
|
219
224
|
if end < 0:
|
220
225
|
end = self.type.shape[dim] + end
|
221
226
|
|
222
|
-
end
|
227
|
+
if end <= start or np.prod(src.type.shape) == 0:
|
228
|
+
return self
|
223
229
|
|
230
|
+
end = start + step * math.ceil((end - start) / step) - (step - 1)
|
224
231
|
padding_low = start
|
225
232
|
padding_high = self.type.shape[dim] - end
|
233
|
+
interior_padding = step - 1
|
226
234
|
|
227
235
|
rank = len(self.type.shape)
|
228
236
|
src = stablehlo.pad(
|
@@ -230,7 +238,9 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
230
238
|
utils.splat(0, src.type.element_type, []),
|
231
239
|
edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
|
232
240
|
edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
|
233
|
-
interior_padding=[
|
241
|
+
interior_padding=[
|
242
|
+
interior_padding if i == dim else 0 for i in range(rank)
|
243
|
+
],
|
234
244
|
)
|
235
245
|
pred = np.ones(self.type.shape, dtype=np.bool_)
|
236
246
|
pred[*[
|
@@ -57,6 +57,7 @@ global_registry.decompositions.update(
|
|
57
57
|
torch._decomp.get_decompositions([
|
58
58
|
torch.ops.aten.upsample_nearest2d,
|
59
59
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
60
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
60
61
|
torch.ops.aten._adaptive_avg_pool2d,
|
61
62
|
torch.ops.aten._adaptive_avg_pool3d,
|
62
63
|
torch.ops.aten.grid_sampler_2d,
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240918
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=jWg5qA8V0XqgFoqjk0SCsNWPRBeTmfrir9u0bucHYOU,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -39,25 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
|
|
39
39
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
40
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
41
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
43
|
-
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
|
42
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
|
43
|
+
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
|
44
44
|
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
|
45
45
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
|
46
46
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
|
-
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
|
47
|
+
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
|
48
48
|
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
|
49
49
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
|
-
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
|
50
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
|
51
51
|
ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
|
52
52
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
53
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
53
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
|
54
54
|
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
|
55
55
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
56
56
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
57
57
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
|
58
58
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
59
59
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
|
60
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
|
61
61
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
62
62
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
63
63
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -75,12 +75,12 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
|
|
75
75
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
|
76
76
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
|
77
77
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
78
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
78
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
|
79
79
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
|
80
80
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
81
81
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
82
82
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
83
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
83
|
+
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
84
84
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
85
85
|
ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
|
86
86
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
|
@@ -90,9 +90,9 @@ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvC
|
|
90
90
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
91
91
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
92
92
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
93
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
|
94
94
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
95
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
|
96
96
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
97
|
ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
|
98
98
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
@@ -108,8 +108,9 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5
|
|
108
108
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
109
109
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
110
110
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
111
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
|
111
112
|
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
|
112
|
-
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=
|
113
|
+
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
|
113
114
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
114
115
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
115
116
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
@@ -141,13 +142,13 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
141
142
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
142
143
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
143
144
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
144
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
145
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VvB050UCjB17h6-UNtsaqzVF13MGI01fPFkdmmghTj4,8790
|
145
146
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
146
147
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
147
148
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
|
148
149
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
149
150
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
150
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
151
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
|
151
152
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
152
153
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
153
154
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -157,8 +158,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
157
158
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
158
159
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
159
160
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
160
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
-
ai_edge_torch_nightly-0.3.0.
|
162
|
-
ai_edge_torch_nightly-0.3.0.
|
163
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
162
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/METADATA,sha256=dMaIr8Iny84IfNGQGSrtlTGkYlH_mAMmgvGWm5-pkxM,1859
|
163
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/RECORD,,
|
File without changes
|
File without changes
|