ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +41 -8
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +7 -2
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +24 -22
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
```diff
--- a/ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
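Every example converter below repeats the same change: the hand-rolled `quant_suffix`/`output_filename` strings are removed, and `converter.convert_to_tflite` instead receives `output_path`, `output_name_prefix`, and an optional `lora_ranks` list, so the output naming presumably now lives in `ai_edge_torch/generative/utilities/converter.py` (changed +100 -47 in this release, hunk not shown here). A rough sketch of the naming scheme this implies, reusing the pattern from the removed lines; `_build_output_name`, the `_lora{rank}` suffix, and the example KV-cache length are illustrative assumptions, not the converter's actual code:

```python
import os
from typing import Optional


# Hypothetical helper mirroring the removed per-script logic; the real naming
# is done inside converter.convert_to_tflite and may differ.
def _build_output_name(
    prefix: str,
    quantize: bool,
    kv_cache_max_len: int,
    lora_rank: Optional[int] = None,
) -> str:
  quant_suffix = 'q8' if quantize else 'f32'  # same convention as the removed lines
  name = f'{prefix}_{quant_suffix}_ekv{kv_cache_max_len}'
  if lora_rank is not None:
    name += f'_lora{lora_rank}'  # assumed suffix for LoRA variants
  return name + '.tflite'


print(os.path.join('/tmp/', _build_output_name('gemma', True, 1024)))
# /tmp/gemma_q8_ekv1024.tflite
print(os.path.join('/tmp/', _build_output_name('gemma', True, 1024, lora_rank=4)))
# /tmp/gemma_q8_ekv1024_lora4.tflite
```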
```diff
--- a/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
```
```diff
--- a/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/phi/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
 
 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
```
```diff
--- a/ai_edge_torch/generative/layers/attention.py
+++ b/ai_edge_torch/generative/layers/attention.py
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -93,6 +94,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
 
@@ -102,6 +104,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -110,7 +113,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -119,7 +124,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -179,6 +186,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
 
@@ -189,7 +197,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this self attention layer, and the updated
@@ -228,6 +237,11 @@ class CausalSelfAttention(nn.Module):
           dim=-1,
       )
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)
 
@@ -244,7 +258,7 @@ class CausalSelfAttention(nn.Module):
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
@@ -252,10 +266,13 @@ class CausalSelfAttention(nn.Module):
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)
 
     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
 
 
@@ -268,6 +285,7 @@ class SelfAttention(CausalSelfAttention):
      rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
 
@@ -275,18 +293,23 @@ class SelfAttention(CausalSelfAttention):
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this self attention layer, and the updated
       KV Cach Entry (if passed in).
     """
     B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
     return super().forward(
         x,
         rope=rope,
         mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
         input_pos=input_pos,
+        lora=lora,
     )
 
 
@@ -343,6 +366,7 @@ class CrossAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ):
     """Forward function of the CrossAttention layer.
 
@@ -353,7 +377,8 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry):
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this cross attention layer.
@@ -366,6 +391,11 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
@@ -388,4 +418,7 @@ class CrossAttention(nn.Module):
 
     # Compute the output projection.
     y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
```
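The new `ai_edge_torch/generative/layers/lora.py` module (+557 lines) and its test are not included in the hunks above; only the call pattern is visible: a `LoRAEntry` carries per-projection weights under `lora.attention.{query,key,value,output}`, and `lora_utils.apply_lora(x, weights, shape=...)` returns a delta that is added onto the base projection output. A minimal sketch of the standard low-rank update that pattern suggests; the dataclass layout, field names, and alpha/rank scaling are assumptions rather than the module's actual API:

```python
import dataclasses
from typing import Optional

import torch


@dataclasses.dataclass
class LoRAWeight:
  """Assumed container for one projection's adapter: W + (alpha / rank) * A @ B."""

  a_prime: torch.Tensor  # (in_features, rank)
  b_prime: torch.Tensor  # (rank, out_features)
  alpha: float = 1.0


def apply_lora(
    x: torch.Tensor,
    lora_weight: LoRAWeight,
    shape: Optional[torch.Size] = None,
) -> torch.Tensor:
  """Computes the low-rank delta (x @ A @ B) * (alpha / rank), reshaped if asked."""
  rank = lora_weight.a_prime.shape[-1]
  delta = (x @ lora_weight.a_prime @ lora_weight.b_prime) * (
      lora_weight.alpha / rank
  )
  return delta.reshape(shape) if shape is not None else delta
```

In the attention hunks this delta is simply summed into `q`, `k`, `v`, and the output projection, which leaves the base weights untouched and matches the usual frozen-base-plus-adapter LoRA decomposition; converting with `--lora_ranks` presumably produces one exported variant per requested adapter size.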