ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250108__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  4. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  5. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  6. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  7. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  8. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  9. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  11. ai_edge_torch/generative/layers/attention.py +41 -8
  12. ai_edge_torch/generative/layers/lora.py +557 -0
  13. ai_edge_torch/generative/test/test_lora.py +147 -0
  14. ai_edge_torch/generative/utilities/converter.py +100 -47
  15. ai_edge_torch/generative/utilities/model_builder.py +7 -2
  16. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  17. ai_edge_torch/odml_torch/export.py +6 -2
  18. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  19. ai_edge_torch/version.py +1 -1
  20. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/METADATA +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/RECORD +24 -22
  22. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/LICENSE +0 -0
  23. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/WHEEL +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250108.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
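The quant_suffix / output_filename logic deleted here reappears in ai_edge_torch/generative/utilities/converter.py (item 14, +100 -47), which now owns output naming for every example script. A minimal sketch of the naming the converter presumably derives from the new flags, assuming it keeps the old scheme; build_output_filename and its exact signature are hypothetical, not taken from this diff:

import os


def build_output_filename(
    output_path: str,
    output_name_prefix: str,
    quantize: bool,
    kv_cache_max_len: int,
) -> str:
  # Mirrors the per-script logic removed above, e.g. 'gemma_q8_ekv1024.tflite'
  # placed under output_path ('/tmp/' by default).
  quant_suffix = 'q8' if quantize else 'f32'
  filename = f'{output_name_prefix}_{quant_suffix}_ekv{kv_cache_max_len}.tflite'
  return os.path.join(output_path, filename)

Since lora_ranks is an absl multi-integer flag, a LoRA-enabled conversion is invoked by repeating it (e.g. --lora_ranks=16 --lora_ranks=32); how the converter names any per-rank outputs is not visible in this diff. The same flag rename repeats nearly verbatim in each example script below.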
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/qwen/convert_to_tflite.py
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/qwen'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'qwen',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,12 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
+
 
 _BUILDER = {
     '0.5b': qwen.build_0_5b_model,
@@ -67,16 +78,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  model_size = _MODEL_SIZE.value.replace('.', '_')
-  output_filename = (
-      f'qwen_{model_size}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'smollm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,20 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = smollm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'tinyllama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,21 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = tiny_llama.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'tinyllama_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
 
ai_edge_torch/generative/layers/attention.py
@@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union
 
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.layers import lora as lora_utils
 from ai_edge_torch.generative.layers import scaled_dot_product_attention as sdpa
 import ai_edge_torch.generative.layers.model_config as cfg
 import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
@@ -93,6 +94,7 @@ class TransformerBlock(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: kv_utils.KVCacheEntry = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the TransformerBlock.
 
@@ -102,6 +104,7 @@ class TransformerBlock(nn.Module):
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
       kv_cache (KVCacheEntry): the optional kv cache entry.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this transformer block, and updated kv cache (if
@@ -110,7 +113,9 @@ class TransformerBlock(nn.Module):
     kv = None
     if self.config.parallel_residual:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
@@ -119,7 +124,9 @@ class TransformerBlock(nn.Module):
       output = x + attn_out + ff_out
     else:
       x_norm = self.pre_atten_norm(x)
-      atten_func_out = self.atten_func(x_norm, rope, mask, input_pos, kv_cache)
+      atten_func_out = self.atten_func(
+          x_norm, rope, mask, input_pos, kv_cache, lora
+      )
       if kv_cache is None:
         attn_out = atten_func_out
       else:
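All four attention entry points in this file now take an optional lora argument, threaded from TransformerBlock.forward down into the attention layers. A hedged sketch of how a model's forward loop would feed per-layer adapters through the new parameter; the container fields caches and adapters are illustrative assumptions, not taken from this diff:

# Illustrative caller only: thread one KVCacheEntry and one LoRAEntry
# into each block.
for i, block in enumerate(self.transformer_blocks):
  kv_entry = kv_cache.caches[i] if kv_cache is not None else None
  lora_entry = lora.adapters[i] if lora is not None else None
  out = block(x, rope, mask, input_pos, kv_entry, lora_entry)
  # TransformerBlock returns (output, updated_kv) only when a cache entry
  # was passed, and the bare output tensor otherwise.
  x, kv_entry = out if kv_entry is not None else (out, None)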
@@ -179,6 +186,7 @@ class CausalSelfAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the CausalSelfAttention layer, which can support
 
@@ -189,7 +197,8 @@ class CausalSelfAttention(nn.Module):
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       mask (torch.Tensor): the optional mask tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this self attention layer, and the updated
@@ -228,6 +237,11 @@ class CausalSelfAttention(nn.Module):
           dim=-1,
       )
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     q = self.query_norm(q)
     k = self.key_norm(k)
 
@@ -244,7 +258,7 @@ class CausalSelfAttention(nn.Module):
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
       k, v = kv_cache.k_cache, kv_cache.v_cache
 
-    y = self.sdpa_func(
+    sdpa_out = self.sdpa_func(
         q,
         k,
         v,
@@ -252,10 +266,13 @@ class CausalSelfAttention(nn.Module):
         mask=mask,
         softcap=self.config.logit_softcap,
     )
-    y = y.reshape(B, T, -1)
+    sdpa_out = sdpa_out.reshape(B, T, -1)
 
     # Compute the output projection.
-    y = self.output_projection(y)
+    y = self.output_projection(sdpa_out)
+    if lora is not None:
+      y += lora_utils.apply_lora(sdpa_out, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
 
 
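Note the placement: the LoRA deltas are added to the raw q/k/v projections before query/key normalization and RoPE, and to the attention output after SDPA, each computed from the input of the projection it corrects (x for q/k/v, sdpa_out for the output). lora.py itself (item 12, +557 lines) is not shown in this view; from the call sites, apply_lora plausibly computes the standard low-rank product x @ A @ B. A sketch under that assumption, with the adapter field names invented:

from typing import Optional

import torch


def apply_lora(
    x: torch.Tensor,
    adapter,  # assumed to hold a_prime (in_dim, rank) and b_prime (rank, out_dim)
    shape: Optional[torch.Size] = None,
) -> torch.Tensor:
  # Standard LoRA update: project down to the adapter rank, then back up.
  output = torch.matmul(torch.matmul(x, adapter.a_prime), adapter.b_prime)
  if shape is not None:
    # Match the layout of the projection the delta is added to.
    output = output.reshape(shape)
  return output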
@@ -268,6 +285,7 @@ class SelfAttention(CausalSelfAttention):
       rope: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ) -> Union[torch.Tensor, Tuple[torch.Tensor, kv_utils.KVCacheEntry]]:
     """Forward function of the SelfAttention layer, which can support MQA, GQA and MHA.
 
@@ -275,18 +293,23 @@ class SelfAttention(CausalSelfAttention):
       x (torch.Tensor): the input tensor.
       rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this self attention layer, and the updated
      KV Cach Entry (if passed in).
    """
    B, T, _ = x.size()
+    assert (
+        kv_cache is None
+    ), "KV cache is not supported in non-causal SelfAttention."
    return super().forward(
        x,
        rope=rope,
        mask=torch.zeros((B, 1, T, T), dtype=torch.float32),
        input_pos=input_pos,
+        lora=lora,
    )
 
 
@@ -343,6 +366,7 @@ class CrossAttention(nn.Module):
       mask: Optional[torch.Tensor] = None,
       input_pos: Optional[torch.Tensor] = None,
       kv_cache: Optional[kv_utils.KVCacheEntry] = None,
+      lora: Optional[lora_utils.LoRAEntry] = None,
   ):
     """Forward function of the CrossAttention layer.
 
@@ -353,7 +377,8 @@ class CrossAttention(nn.Module):
       mask (torch.Tensor): the optional mask tensor can be broadcaseted to shape
         [B, n_heads, target_seq_len, source_seq_len].
       input_pos (torch.Tensor): the optional input position tensor.
-      kv_cache (KVCacheEntry): The KV cache entry corresponding to this module.
+      kv_cache (KVCacheEntry): the KV cache entry corresponding to this module.
+      lora (LoRAEntry): the optional lora entry.
 
     Returns:
       output activation from this cross attention layer.
@@ -366,6 +391,11 @@ class CrossAttention(nn.Module):
     k = self.k_projection(y)
     v = self.v_projection(y)
 
+    if lora is not None:
+      q += lora_utils.apply_lora(x, lora.attention.query, shape=q.shape)
+      k += lora_utils.apply_lora(x, lora.attention.key, shape=k.shape)
+      v += lora_utils.apply_lora(x, lora.attention.value, shape=v.shape)
+
     interim_shape = (batch_size, -1, self.n_heads, self.config.head_dim)
     q = q.view(interim_shape)
     k = k.view(interim_shape)
@@ -388,4 +418,7 @@
 
     # Compute the output projection.
     y = self.output_projection(y)
+    if lora is not None:
+      y += lora_utils.apply_lora(y, lora.attention.output)
+
     return y if kv_cache is None else (y, kv_cache)
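Taken together, the call sites (lora.attention.query, .key, .value, .output) pin down the rough shape of the LoRAEntry type that lora.py exports. A minimal dataclass sketch consistent with those accesses; every name below other than LoRAEntry is an assumption:

import dataclasses

import torch


@dataclasses.dataclass
class LoRAWeight:
  # One low-rank pair per linear projection; field names assumed.
  a_prime: torch.Tensor  # (in_features, rank)
  b_prime: torch.Tensor  # (rank, out_features)


@dataclasses.dataclass
class AttentionLoRA:
  query: LoRAWeight
  key: LoRAWeight
  value: LoRAWeight
  output: LoRAWeight


@dataclasses.dataclass
class LoRAEntry:
  attention: AttentionLoRA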