ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240918__py3-none-any.whl

Files changed (20)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
  3. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
  4. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
  5. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
  6. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
  7. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
  8. ai_edge_torch/generative/layers/attention.py +8 -4
  9. ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -0
  10. ai_edge_torch/generative/layers/unet/model_config.py +2 -0
  11. ai_edge_torch/generative/utilities/converter.py +82 -0
  12. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +6 -0
  13. ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
  14. ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
  15. ai_edge_torch/version.py +1 -1
  16. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/RECORD +20 -19
  18. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/LICENSE +0 -0
  19. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/WHEEL +0 -0
  20. {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma2
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma2 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma2.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b')
-  convert_gemma2_to_tflite(path)
+  app.run(main)
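
All seven example converters in this release get the same refactor: the positional convert_*_to_tflite() entry point is replaced by absl flags plus a main() handed to app.run(), with the actual conversion delegated to the new shared converter utility (added later in this diff). A minimal runnable sketch of that absl pattern, using an illustrative flag rather than one taken from the diff:

from absl import app
from absl import flags

_QUANTIZE = flags.DEFINE_bool('quantize', True, 'Whether to quantize.')


def main(_):
  # Flag values are only readable after app.run() has parsed sys.argv.
  print('quantize =', _QUANTIZE.value)


if __name__ == '__main__':
  app.run(main)  # e.g. `python script.py --noquantize`

The remaining convert scripts below differ only in model module, default checkpoint path, and output filename.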
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.gemma import gemma
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/gemma_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_gemma_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Gemma 2B model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = gemma.build_2b_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/gemma_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b')
-  convert_gemma_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.openelm import openelm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/openelm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_openelm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts OpenELM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = openelm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/openelm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm')
-  convert_openelm_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/phi/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.phi import phi2
-from ai_edge_torch.generative.layers import kv_cache
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/phi2_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_phi2_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts a Phi-2 model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = phi2.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_cache.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/phi2_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2')
-  convert_phi2_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.smollm import smollm
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/smollm_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_smollm_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts SmolLM model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = smollm.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/smollm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smollm')
-  convert_smollm_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py CHANGED
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=mid_block_channels,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=mid_block_channels,
+                output_dim=mid_block_channels,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
             cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 attention_batch_size=config.transformer_batch_size,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py CHANGED
@@ -18,69 +18,49 @@
 import os
 import pathlib
 
-import ai_edge_torch
+from absl import app
+from absl import flags
 from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers import kv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
+from ai_edge_torch.generative.utilities import converter
 
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/tiny_llama_q8_seq512_ekv1024.tflite',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    512,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1024,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
 
-def convert_tiny_llama_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """Converts TinyLlama model to multi-signature tflite model.
 
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
+def main(_):
   pytorch_model = tiny_llama.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  quant_suffix = 'q8' if quantize else 'f32'
-  edge_model.export(
-      f'/tmp/tiny_llama_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=_TFLITE_PATH.value,
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      quantize=_QUANTIZE.value,
   )
 
 
 if __name__ == '__main__':
-  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(path)
+  app.run(main)
ai_edge_torch/generative/layers/attention.py CHANGED
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
       batch_size: int,
       query_dim: int,
       cross_dim: int,
+      hidden_dim: int,
+      output_dim: int,
      config: cfg.AttentionConfig,
       enable_hlfb: bool,
   ):
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
       batch_size (int): batch size of the input tensor.
       query_dim (int): query tensor's dimension.
       cross_dim (int): cross attention's dimensions, for key and value tensors.
+      hidden_dim (int): hidden dimension that q, k, v tensors project to.
+      output_dim (int): output tensor's dimension.
       config (cfg.AttentionConfig): attention specific configurations.
       enable_hlfb (bool): whether hlfb is enabled or not.
     """
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
     self.config = config
     self.n_heads = config.num_heads
     self.q_projection = nn.Linear(
-        query_dim, query_dim, bias=config.qkv_use_bias
+        query_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.k_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.v_projection = nn.Linear(
-        cross_dim, query_dim, bias=config.qkv_use_bias
+        cross_dim, hidden_dim, bias=config.qkv_use_bias
     )
     self.output_projection = nn.Linear(
-        query_dim, query_dim, bias=config.output_proj_use_bias
+        hidden_dim, output_dim, bias=config.output_proj_use_bias
     )
 
     self.sdpa_func = (
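
In shape terms, CrossAttention previously projected everything to query_dim; it now projects q, k, and v to hidden_dim and the output to output_dim. A hedged sketch with illustrative dimensions (the stable diffusion configs in this diff pass the block's channel count for both new fields, which reproduces the old behavior):

import torch
from torch import nn

query_dim, cross_dim = 320, 768   # illustrative: UNet channels, text-encoder width
hidden_dim = output_dim = 320     # what the diffusion configs pass

q_proj = nn.Linear(query_dim, hidden_dim)     # was query_dim -> query_dim
k_proj = nn.Linear(cross_dim, hidden_dim)     # was cross_dim -> query_dim
v_proj = nn.Linear(cross_dim, hidden_dim)     # was cross_dim -> query_dim
out_proj = nn.Linear(hidden_dim, output_dim)  # was query_dim -> query_dim

x = torch.randn(1, 64, query_dim)    # image-feature queries
ctx = torch.randn(1, 77, cross_dim)  # cross-attention context
assert q_proj(x).shape == (1, 64, hidden_dim)
assert k_proj(ctx).shape == (1, 77, hidden_dim)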
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -178,6 +178,8 @@ class CrossAttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.query_dim,
         config.cross_dim,
+        config.hidden_dim,
+        config.output_dim,
         config.attention_config,
         enable_hlfb=config.enable_hlfb,
     )
ai_edge_torch/generative/layers/unet/model_config.py CHANGED
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
 class CrossAttentionBlock2DConfig:
   query_dim: int
   cross_dim: int
+  hidden_dim: int
+  output_dim: int
   normalization_config: layers_cfg.NormalizationConfig
   attention_config: layers_cfg.AttentionConfig
   enable_hlfb: bool = True
ai_edge_torch/generative/utilities/converter.py ADDED
@@ -0,0 +1,82 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common utility functions for model conversion."""
+
+import ai_edge_torch
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_to_tflite(
+    pytorch_model: torch.nn.Module,
+    tflite_path: str,
+    prefill_seq_len: int = 512,
+    quantize: bool = True,
+):
+  """Converts a nn.Module model to multi-signature tflite model.
+
+  A PyTorch model will be converted to a tflite model with two signatures:
+  "prefill" and "decode".
+
+  "prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
+  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
+  external KV cache as a sample input.
+
+  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
+  of shape [1, 1] of the token position, and an external KV cache as a sample
+  input.
+
+  The final tflite model will be exported to tflite_path.
+
+  Args:
+    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
+    tflite_path (str): The tflite file path to export.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    quantize (bool, optional): Whether the model should be quanized. Defaults
+      to True.
+  """
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
+  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
+  decode_token = torch.tensor([[0]], dtype=torch.int)
+  decode_input_pos = torch.tensor([0], dtype=torch.int)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  edge_model.export(tflite_path)
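
Putting the pieces together, programmatic use of the new utility mirrors the rewritten example scripts; a sketch, with the checkpoint path a placeholder:

from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.utilities import converter

# Build the PyTorch model, then let the shared helper trace the 'prefill'
# and 'decode' signatures and write the .tflite file.
pytorch_model = gemma2.build_2b_model(
    '/path/to/gemma2-2b',  # placeholder checkpoint path
    kv_cache_max_len=1024,
)
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/gemma2_q8_seq512_ekv1024.tflite',
    prefill_seq_len=512,
    quantize=True,
)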
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED
@@ -811,6 +811,8 @@ class DiffusionModelLoader(BaseLoader):
             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
                     num_heads=config.transformer_num_attention_heads,
@@ -877,6 +879,8 @@ class DiffusionModelLoader(BaseLoader):
             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                 query_dim=mid_block_channels,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=mid_block_channels,
+                output_dim=mid_block_channels,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
                     num_heads=config.transformer_num_attention_heads,
@@ -950,6 +954,8 @@ class DiffusionModelLoader(BaseLoader):
             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 normalization_config=config.transformer_norm_config,
                 attention_config=build_attention_config(
                     num_heads=config.transformer_num_attention_heads,
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -212,17 +212,25 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
 # - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
 @lower(torch.ops.aten.slice_scatter)
 def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
-  start = start or 0
-  end = end or self.type.shape[dim]
+  start = start if start is not None else 0
+  end = end if end is not None else self.type.shape[dim]
+
+  start, end = np.clip(
+      [start, end], -self.type.shape[dim], self.type.shape[dim]
+  )
+
   if start < 0:
     start = self.type.shape[dim] + start
   if end < 0:
     end = self.type.shape[dim] + end
 
-  end = start + step * math.ceil((end - start) / step) - (step - 1)
+  if end <= start or np.prod(src.type.shape) == 0:
+    return self
 
+  end = start + step * math.ceil((end - start) / step) - (step - 1)
   padding_low = start
   padding_high = self.type.shape[dim] - end
+  interior_padding = step - 1
 
   rank = len(self.type.shape)
   src = stablehlo.pad(
@@ -230,7 +238,9 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
       utils.splat(0, src.type.element_type, []),
       edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
       edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
-      interior_padding=[step - 1 if i == dim else 0 for i in range(rank)],
+      interior_padding=[
+          interior_padding if i == dim else 0 for i in range(rank)
+      ],
   )
   pred = np.ones(self.type.shape, dtype=np.bool_)
   pred[*[
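
The clamping and early return bring the lowering in line with eager slice_scatter semantics; a hedged illustration of the cases now covered:

import torch

x = torch.zeros(8)

# Strided scatter: the lowering interleaves src using interior padding of
# step - 1 elements.
y = torch.slice_scatter(x, torch.ones(4), dim=0, start=0, end=8, step=2)
# y == [1, 0, 1, 0, 1, 0, 1, 0]

# Out-of-range bounds are clamped to the dim size, like Python slicing
# (the new np.clip call).
z = torch.slice_scatter(x, torch.ones(4), dim=0, start=0, end=100, step=2)

# Empty slice: src contributes nothing and the input passes through
# (the new `end <= start or np.prod(src.type.shape) == 0` early return).
w = torch.slice_scatter(x, torch.empty(0), dim=0, start=4, end=4)
assert torch.equal(w, x)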
ai_edge_torch/odml_torch/lowerings/registry.py CHANGED
@@ -57,6 +57,7 @@ global_registry.decompositions.update(
     torch._decomp.get_decompositions([
         torch.ops.aten.upsample_nearest2d,
         torch.ops.aten._native_batch_norm_legit.no_stats,
+        torch.ops.aten._native_batch_norm_legit_functional,
         torch.ops.aten._adaptive_avg_pool2d,
         torch.ops.aten._adaptive_avg_pool3d,
         torch.ops.aten.grid_sampler_2d,
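
For context (an inference about intent, not stated in the diff): aten._native_batch_norm_legit_functional is the op functionalization emits for training-mode batch norm that updates running stats, so registering its decomposition lets such graphs lower. A sketch of a module that would produce it:

import torch


class M(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.bn = torch.nn.BatchNorm2d(3)

  def forward(self, x):
    return self.bn(x)


# Exported in train mode, the running-stat mutation is functionalized into
# aten._native_batch_norm_legit_functional, the decomposition added above.
ep = torch.export.export(M().train(), (torch.randn(1, 3, 8, 8),))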
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240916"
+__version__ = "0.3.0.dev20240918"
{ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240916
+Version: 0.3.0.dev20240918
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240918.dist-info}/RECORD RENAMED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=nRUErTd6i3Pxfpnp3BacFfEH5cQbDvxrA6YeTzKNOxU,706
+ai_edge_torch/version.py,sha256=jWg5qA8V0XqgFoqjk0SCsNWPRBeTmfrir9u0bucHYOU,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,25 +39,25 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=bN_dtqi5C_dHpLsvXJ9vCb9OnZ0frLeyYoWBXZYJEqA,3061
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=fiFKkEe3TgOdpLnzsCZzIdwvEz0ikxDavQcRGQhlkBY,3053
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
+ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
 ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
 ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=DgBuR1uq4YQWfWiENBxrx7UCVr4Jc5kWCyoi6ii5DTE,3058
+ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=kQTJlCDz_DHLRLlVWE0JEpbOjIGAKtxH1fTSc-jn1nU,8498
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=_tP5ArL0FKiBNoOqN2rG351IzmhNKQmWUfewlcSdKDs,3024
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=mGyBI-nORoI-LhZkI4MFAonkUflIX9iimAer_K8jpck,7088
 ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=66APmBId5UayZ7SWSO1zxcLiM8TucOMA-fFEHhm61qs,3049
+ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
 ai_edge_torch/generative/examples/smollm/smollm.py,sha256=_nK2DAOiSuxv5o8ip0i-gmhvvjwF5e7Dm3m5VTcsR2M,4276
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
 ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
-ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
+ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
 ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
 ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -75,12 +75,12 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5u6aOiCVahHNCgax5k9a8uhJn9eMzLa19ldscFKNyWo,3083
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkdczTI-Lr-1GDg0R2g4SUUGEMTUZ5uY,7023
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=37Fua94dQSiBA9Y5XvHxGb5IfN8p8UgNgu5YwM1Rmrw,13057
+ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
 ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
 ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
@@ -90,9 +90,9 @@ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvC
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
-ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
+ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
 ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -108,8 +108,9 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=TD7dELN5cVw5
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
+ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
 ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
-ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
+ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
@@ -141,13 +142,13 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
-ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=E5j_xHuyDmA9fcgoi6p04zLGV9mFleyXzx6jSBi2wD0,8529
+ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VvB050UCjB17h6-UNtsaqzVF13MGI01fPFkdmmghTj4,8790
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
 ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
+ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -157,8 +158,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/METADATA,sha256=yK-gW8Z98p5-9PvIsfCu3f5FAACNAPH5_BecOImrfKo,1859
-ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/METADATA,sha256=dMaIr8Iny84IfNGQGSrtlTGkYlH_mAMmgvGWm5-pkxM,1859
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240918.dist-info/RECORD,,